From 813fe24b88c3595fc474a447e6b4bbe86695b031 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sun, 25 Jun 2023 02:11:46 -0700 Subject: [PATCH 01/32] chore: update crate version to v0.7.0 --- CHANGELOG.md | 2 ++ Cargo.toml | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f83519..355d33f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +## 0.7.0 (TBD) + ## 0.6.0 (2023-06-25) * [BREAKING] Added support for recording capabilities for `MerkleStore` (#162). diff --git a/Cargo.toml b/Cargo.toml index 8c9a3b5..7bffb94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "miden-crypto" -version = "0.6.0" +version = "0.7.0" description = "Miden Cryptographic primitives" authors = ["miden contributors"] readme = "README.md" license = "MIT" repository = "https://github.com/0xPolygonMiden/crypto" -documentation = "https://docs.rs/miden-crypto/0.6.0" +documentation = "https://docs.rs/miden-crypto/0.7.0" categories = ["cryptography", "no-std"] keywords = ["miden", "crypto", "hash", "merkle"] edition = "2021" From 08aec4443ccfb2dfab4dc193046ccec6551de687 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Thu, 6 Jul 2023 00:19:03 +0300 Subject: [PATCH 02/32] Enhancement of the Partial Merkle Tree (#163) feat: implement additional functionality for the PartialMerkleTree --- src/merkle/index.rs | 17 ++++ src/merkle/partial_mt/mod.rs | 176 +++++++++++++++++++++++++++++---- src/merkle/partial_mt/tests.rs | 172 +++++++++++++++++++++++++++++--- src/merkle/store/mod.rs | 8 +- src/utils/mod.rs | 4 +- 5 files changed, 342 insertions(+), 35 deletions(-) diff --git a/src/merkle/index.rs b/src/merkle/index.rs index f17216f..3a79ac0 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -1,4 +1,5 @@ use super::{Felt, MerkleError, RpoDigest, StarkField}; +use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::fmt::Display; // NODE INDEX @@ -161,6 +162,22 @@ impl Display for NodeIndex { } } +impl Serializable for NodeIndex { + fn write_into(&self, target: &mut W) { + target.write_u8(self.depth); + target.write_u64(self.value); + } +} + +impl Deserializable for NodeIndex { + fn read_from(source: &mut R) -> Result { + let depth = source.read_u8()?; + let value = source.read_u64()?; + NodeIndex::new(depth, value) + .map_err(|_| DeserializationError::InvalidValue("Invalid index".into())) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 3558c9f..ef87516 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -1,7 +1,11 @@ use super::{ - BTreeMap, BTreeSet, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Vec, ZERO, + BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, + ValuePath, Vec, Word, ZERO, +}; +use crate::utils::{ + format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable, + DeserializationError, Serializable, }; -use crate::utils::{format, string::String, word_to_hex}; use core::fmt; #[cfg(test)] @@ -74,6 +78,92 @@ impl PartialMerkleTree { }) } + /// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided + /// entries. + /// + /// # Errors + /// Returns an error if: + /// - If the depth is 0 or is greater than 64. + /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}. 
+ /// - The provided entries contain an insufficient set of nodes. + pub fn with_leaves(entries: R) -> Result + where + R: IntoIterator, + I: Iterator + ExactSizeIterator, + { + let mut layers: BTreeMap> = BTreeMap::new(); + let mut leaves = BTreeSet::new(); + let mut nodes = BTreeMap::new(); + + // add data to the leaves and nodes maps and also fill layers map, where the key is the + // depth of the node and value is its index. + for (node_index, hash) in entries.into_iter() { + leaves.insert(node_index); + nodes.insert(node_index, hash); + layers + .entry(node_index.depth()) + .and_modify(|layer_vec| layer_vec.push(node_index.value())) + .or_insert(vec![node_index.value()]); + } + + // check if the number of leaves can be accommodated by the tree's depth; we use a min + // depth of 63 because we consider passing in a vector of size 2^64 infeasible. + let max = (1_u64 << 63) as usize; + if layers.len() > max { + return Err(MerkleError::InvalidNumEntries(max, layers.len())); + } + + // Get maximum depth + let max_depth = *layers.keys().next_back().unwrap_or(&0); + + // fill layers without nodes with empty vector + for depth in 0..max_depth { + layers.entry(depth).or_insert(vec![]); + } + + let mut layer_iter = layers.into_values().rev(); + let mut parent_layer = layer_iter.next().unwrap(); + let mut current_layer; + + for depth in (1..max_depth + 1).rev() { + // set current_layer = parent_layer and parent_layer = layer_iter.next() + current_layer = layer_iter.next().unwrap(); + core::mem::swap(&mut current_layer, &mut parent_layer); + + for index_value in current_layer { + // get the parent node index + let parent_node = NodeIndex::new(depth - 1, index_value / 2)?; + + // Check if the parent hash was already calculated. In about half of the cases, we + // don't need to do anything. + if !parent_layer.contains(&parent_node.value()) { + // create current node index + let index = NodeIndex::new(depth, index_value)?; + + // get hash of the current node + let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?; + // get hash of the sibling node + let sibling = nodes + .get(&index.sibling()) + .ok_or(MerkleError::NodeNotInSet(index.sibling()))?; + // get parent hash + let parent = Rpo256::merge(&index.build_node(*node, *sibling)); + + // add index value of the calculated node to the parents layer + parent_layer.push(parent_node.value()); + // add index and hash to the nodes map + nodes.insert(parent_node, parent); + } + } + } + + Ok(PartialMerkleTree { + max_depth, + nodes, + leaves, + }) + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -101,7 +191,7 @@ impl PartialMerkleTree { } /// Returns a vector of paths from every leaf to the root. - pub fn paths(&self) -> Vec<(NodeIndex, ValuePath)> { + pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> { let mut paths = Vec::new(); self.leaves.iter().for_each(|&leaf| { paths.push(( @@ -160,6 +250,22 @@ impl PartialMerkleTree { }) } + /// Returns an iterator over the inner nodes of this Merkle tree. 
+ pub fn inner_nodes(&self) -> impl Iterator + '_ { + let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index)); + inner_nodes.map(|(index, digest)| { + let left_hash = + self.nodes.get(&index.left_child()).expect("Failed to get left child hash"); + let right_hash = + self.nodes.get(&index.right_child()).expect("Failed to get right child hash"); + InnerNodeInfo { + value: *digest, + left: *left_hash, + right: *right_hash, + } + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -235,37 +341,37 @@ impl PartialMerkleTree { /// Updates value of the leaf at the specified index returning the old leaf value. /// + /// By default the specified index is assumed to belong to the deepest layer. If the considered + /// node does not belong to the tree, the first node on the way to the root will be changed. + /// /// This also recomputes all hashes between the leaf and the root, updating the root itself. /// /// # Errors /// Returns an error if: - /// - The depth of the specified node_index is greater than 64 or smaller than 1. - /// - The specified node index is not corresponding to the leaf. - pub fn update_leaf( - &mut self, - node_index: NodeIndex, - value: RpoDigest, - ) -> Result { - // check correctness of the depth and update it - Self::check_depth(node_index.depth())?; - self.update_depth(node_index.depth()); + /// - The specified index is greater than the maximum number of nodes on the deepest layer. + pub fn update_leaf(&mut self, index: u64, value: Word) -> Result { + let mut node_index = NodeIndex::new(self.max_depth(), index)?; - // insert NodeIndex to the leaves Set - self.leaves.insert(node_index); + // proceed to the leaf + for _ in 0..node_index.depth() { + if !self.leaves.contains(&node_index) { + node_index.move_up(); + } + } // add node value to the nodes Map let old_value = self .nodes - .insert(node_index, value) + .insert(node_index, value.into()) .ok_or(MerkleError::NodeNotInSet(node_index))?; // if the old value and new value are the same, there is nothing to update - if value == old_value { + if value == *old_value { return Ok(old_value); } let mut node_index = node_index; - let mut value = value; + let mut value = value.into(); for _ in 0..node_index.depth() { let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); value = Rpo256::merge(&node_index.build_node(value, *sibling)); @@ -327,3 +433,37 @@ impl PartialMerkleTree { Ok(()) } } + +// SERIALIZATION +// ================================================================================================ + +impl Serializable for PartialMerkleTree { + fn write_into(&self, target: &mut W) { + // write leaf nodes + target.write_u64(self.leaves.len() as u64); + for leaf_index in self.leaves.iter() { + leaf_index.write_into(target); + self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target); + } + } +} + +impl Deserializable for PartialMerkleTree { + fn read_from(source: &mut R) -> Result { + let leaves_len = source.read_u64()? 
as usize; + let mut leaf_nodes = Vec::with_capacity(leaves_len); + + // add leaf nodes to the vector + for _ in 0..leaves_len { + let index = NodeIndex::read_from(source)?; + let hash = RpoDigest::read_from(source)?; + leaf_nodes.push((index, hash)); + } + + let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| { + DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into()) + })?; + + Ok(pmt) + } +} diff --git a/src/merkle/partial_mt/tests.rs b/src/merkle/partial_mt/tests.rs index ed5281f..4e580d2 100644 --- a/src/merkle/partial_mt/tests.rs +++ b/src/merkle/partial_mt/tests.rs @@ -1,9 +1,9 @@ use super::{ super::{ - digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, - PartialMerkleTree, + digests_to_words, int_to_node, BTreeMap, DefaultMerkleStore as MerkleStore, MerkleTree, + NodeIndex, PartialMerkleTree, }, - RpoDigest, ValuePath, Vec, + Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath, Vec, }; // TEST DATA @@ -13,6 +13,7 @@ const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0); const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1); const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0); +const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1); const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2); const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3); @@ -50,6 +51,43 @@ const VALUES8: [RpoDigest; 8] = [ // NodeIndex(3, 5) will be labeled as `35`. Leaves of the tree are shown as nodes with parenthesis // (33). +/// Checks that creation of the PMT with `with_leaves()` constructor is working correctly. +#[test] +fn with_leaves() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let leaf_nodes_vec = vec![ + (NODE20, mt.get_node(NODE20).unwrap()), + (NODE32, mt.get_node(NODE32).unwrap()), + (NODE33, mt.get_node(NODE33).unwrap()), + (NODE22, mt.get_node(NODE22).unwrap()), + (NODE23, mt.get_node(NODE23).unwrap()), + ]; + + let leaf_nodes: BTreeMap = leaf_nodes_vec.into_iter().collect(); + + let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap(); + + assert_eq!(expected_root, pmt.root()) +} + +/// Checks that `with_leaves()` function returns an error when using incomplete set of nodes. +#[test] +fn err_with_leaves() { + // NODE22 is missing + let leaf_nodes_vec = vec![ + (NODE20, int_to_node(20)), + (NODE32, int_to_node(32)), + (NODE33, int_to_node(33)), + (NODE23, int_to_node(23)), + ]; + + let leaf_nodes: BTreeMap = leaf_nodes_vec.into_iter().collect(); + + assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err()); +} + /// Checks that root returned by `root()` function is equal to the expected one. #[test] fn get_root() { @@ -61,7 +99,7 @@ fn get_root() { let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); - assert_eq!(pmt.root(), expected_root); + assert_eq!(expected_root, pmt.root()); } /// This test checks correctness of the `add_path()` and `get_path()` functions. 
First it creates a @@ -121,7 +159,7 @@ fn update_leaf() { let new_value32 = int_to_node(132); let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root; - pmt.update_leaf(NODE32, new_value32).unwrap(); + pmt.update_leaf(2, *new_value32).unwrap(); let actual_root = pmt.root(); assert_eq!(expected_root, actual_root); @@ -129,7 +167,15 @@ fn update_leaf() { let new_value20 = int_to_node(120); let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root; - pmt.update_leaf(NODE20, new_value20).unwrap(); + pmt.update_leaf(0, *new_value20).unwrap(); + let actual_root = pmt.root(); + + assert_eq!(expected_root, actual_root); + + let new_value11 = int_to_node(111); + let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root; + + pmt.update_leaf(6, *new_value11).unwrap(); let actual_root = pmt.root(); assert_eq!(expected_root, actual_root); @@ -177,7 +223,7 @@ fn get_paths() { }) .collect(); - let actual_paths = pmt.paths(); + let actual_paths = pmt.to_paths(); assert_eq!(expected_paths, actual_paths); } @@ -247,6 +293,113 @@ fn leaves() { assert!(expected_leaves.eq(pmt.leaves())); } +/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected ones. +#[test] +fn test_inner_node_iterator() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let ms = MerkleStore::from(&mt); + + let path33 = ms.get_path(expected_root, NODE33).unwrap(); + let path22 = ms.get_path(expected_root, NODE22).unwrap(); + + let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); + + // get actual inner nodes + let actual: Vec = pmt.inner_nodes().collect(); + + let expected_n00 = mt.root(); + let expected_n10 = mt.get_node(NODE10).unwrap(); + let expected_n11 = mt.get_node(NODE11).unwrap(); + let expected_n20 = mt.get_node(NODE20).unwrap(); + let expected_n21 = mt.get_node(NODE21).unwrap(); + let expected_n32 = mt.get_node(NODE32).unwrap(); + let expected_n33 = mt.get_node(NODE33).unwrap(); + + // create vector of the expected inner nodes + let mut expected = vec![ + InnerNodeInfo { + value: expected_n00, + left: expected_n10, + right: expected_n11, + }, + InnerNodeInfo { + value: expected_n10, + left: expected_n20, + right: expected_n21, + }, + InnerNodeInfo { + value: expected_n21, + left: expected_n32, + right: expected_n33, + }, + ]; + + assert_eq!(actual, expected); + + // add another path to the Partial Merkle Tree + pmt.add_path(2, path22.value, path22.path).unwrap(); + + // get new actual inner nodes + let actual: Vec = pmt.inner_nodes().collect(); + + let expected_n22 = mt.get_node(NODE22).unwrap(); + let expected_n23 = mt.get_node(NODE23).unwrap(); + + let info_11 = InnerNodeInfo { + value: expected_n11, + left: expected_n22, + right: expected_n23, + }; + + // add new inner node to the existing vertor + expected.insert(2, info_11); + + assert_eq!(actual, expected); +} + +/// Checks that serialization and deserialization implementations for the PMT are working +/// correctly. 
+#[test] +fn serialization() { + let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap(); + let expected_root = mt.root(); + + let ms = MerkleStore::from(&mt); + + let path33 = ms.get_path(expected_root, NODE33).unwrap(); + let path22 = ms.get_path(expected_root, NODE22).unwrap(); + + let pmt = PartialMerkleTree::with_paths([ + (3, path33.value, path33.path), + (2, path22.value, path22.path), + ]) + .unwrap(); + + let serialized_pmt = pmt.to_bytes(); + let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap(); + + assert_eq!(deserialized_pmt, pmt); +} + +/// Checks that deserialization fails with incorrect data. +#[test] +fn err_deserialization() { + let mut tree_bytes: Vec = vec![5]; + tree_bytes.append(&mut NODE20.to_bytes()); + tree_bytes.append(&mut int_to_node(20).to_bytes()); + + tree_bytes.append(&mut NODE21.to_bytes()); + tree_bytes.append(&mut int_to_node(21).to_bytes()); + + // node with depth 1 could have index 0 or 1, but it has 2 + tree_bytes.append(&mut vec![1, 2]); + tree_bytes.append(&mut int_to_node(11).to_bytes()); + + assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err()); +} + /// Checks that addition of the path with different root will cause an error. #[test] fn err_add_path() { @@ -306,8 +459,5 @@ fn err_update_leaf() { let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap(); - assert!(pmt.update_leaf(NODE22, int_to_node(22)).is_err()); - assert!(pmt.update_leaf(NODE23, int_to_node(23)).is_err()); - assert!(pmt.update_leaf(NODE30, int_to_node(30)).is_err()); - assert!(pmt.update_leaf(NODE31, int_to_node(31)).is_err()); + assert!(pmt.update_leaf(8, *int_to_node(38)).is_err()); } diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index fdba5ed..d78be59 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -438,21 +438,21 @@ impl> From<&TieredSmt> for MerkleStore { impl> From for MerkleStore { fn from(values: T) -> Self { - let nodes = values.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = values.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } impl> FromIterator for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = combine_nodes_with_empty_hashes(iter.into_iter()).collect(); + let nodes = combine_nodes_with_empty_hashes(iter).collect(); Self { nodes } } } impl> FromIterator<(RpoDigest, StoreNode)> for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = iter.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = iter.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } @@ -553,5 +553,5 @@ fn combine_nodes_with_empty_hashes( }, ) }) - .chain(empty_hashes().into_iter()) + .chain(empty_hashes()) } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7804420..8059d26 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -2,10 +2,10 @@ use super::{utils::string::String, Word}; use core::fmt::{self, Write}; #[cfg(not(feature = "std"))] -pub use alloc::format; +pub use alloc::{format, vec}; #[cfg(feature = "std")] -pub use std::format; +pub use std::{format, vec}; mod kv_map; From 44e60e7228a2434e982782ea912cfb8155a148d6 Mon Sep 17 00:00:00 2001 From: frisitano Date: Fri, 30 Jun 2023 17:18:22 +0100 Subject: [PATCH 03/32] feat: introduce diff traits and objects --- src/merkle/store/mod.rs | 31 +++++-- src/utils/diff.rs | 16 ++++ src/utils/kv_map.rs | 182 +++++++++++++++++++++++++++++++++++++++- src/utils/mod.rs | 2 + 4 files changed, 224 insertions(+), 7 deletions(-) create mode 100644 
src/utils/diff.rs diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index fdba5ed..f77d558 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -3,7 +3,10 @@ use super::{ MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, ValuePath, Vec, }; -use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use crate::utils::{ + collections::{ApplyDiff, Diff, KvMapDiff}, + ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, +}; use core::borrow::Borrow; #[cfg(test)] @@ -438,21 +441,21 @@ impl> From<&TieredSmt> for MerkleStore { impl> From for MerkleStore { fn from(values: T) -> Self { - let nodes = values.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = values.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } impl> FromIterator for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = combine_nodes_with_empty_hashes(iter.into_iter()).collect(); + let nodes = combine_nodes_with_empty_hashes(iter).collect(); Self { nodes } } } impl> FromIterator<(RpoDigest, StoreNode)> for MerkleStore { fn from_iter>(iter: I) -> Self { - let nodes = iter.into_iter().chain(empty_hashes().into_iter()).collect(); + let nodes = iter.into_iter().chain(empty_hashes()).collect(); Self { nodes } } } @@ -474,6 +477,24 @@ impl> Extend for MerkleStore { } } +// DiffT & ApplyDiffT TRAIT IMPLEMENTATION +// ================================================================================================ +impl> Diff for MerkleStore { + type DiffType = KvMapDiff; + + fn diff(&self, other: &Self) -> Self::DiffType { + self.nodes.diff(&other.nodes) + } +} + +impl> ApplyDiff for MerkleStore { + type DiffType = KvMapDiff; + + fn apply(&mut self, diff: Self::DiffType) { + self.nodes.apply(diff); + } +} + // SERIALIZATION // ================================================================================================ @@ -553,5 +574,5 @@ fn combine_nodes_with_empty_hashes( }, ) }) - .chain(empty_hashes().into_iter()) + .chain(empty_hashes()) } diff --git a/src/utils/diff.rs b/src/utils/diff.rs new file mode 100644 index 0000000..48c80b6 --- /dev/null +++ b/src/utils/diff.rs @@ -0,0 +1,16 @@ +/// A trait for computing the difference between two objects. +pub trait Diff { + type DiffType; + + /// Returns a `Self::DiffType` object that represents the difference between this object and + /// other. + fn diff(&self, other: &Self) -> Self::DiffType; +} + +/// A trait for applying the difference between two objects. +pub trait ApplyDiff { + type DiffType; + + /// Applies the provided changes described by [DiffType] to the object implementing this trait. 
+ fn apply(&mut self, diff: Self::DiffType); +} diff --git a/src/utils/kv_map.rs b/src/utils/kv_map.rs index d9b453d..063a0a0 100644 --- a/src/utils/kv_map.rs +++ b/src/utils/kv_map.rs @@ -1,3 +1,4 @@ +use super::{collections::ApplyDiff, diff::Diff}; use core::cell::RefCell; use winter_utils::{ collections::{btree_map::IntoIter, BTreeMap, BTreeSet}, @@ -18,6 +19,7 @@ pub trait KvMap: self.len() == 0 } fn insert(&mut self, key: K, value: V) -> Option; + fn remove(&mut self, key: &K) -> Option; fn iter(&self) -> Box + '_>; } @@ -42,6 +44,10 @@ impl KvMap for BTreeMap { self.insert(key, value) } + fn remove(&mut self, key: &K) -> Option { + self.remove(key) + } + fn iter(&self) -> Box + '_> { Box::new(self.iter()) } @@ -56,8 +62,9 @@ impl KvMap for BTreeMap { /// /// The [RecordingMap] is composed of three parts: /// - `data`: which contains the current set of key-value pairs in the map. -/// - `updates`: which tracks keys for which values have been since the map was instantiated. -/// updates include both insertions and updates of values under existing keys. +/// - `updates`: which tracks keys for which values have been changed since the map was +/// instantiated. updates include both insertions, removals and updates of values under existing +/// keys. /// - `trace`: which contains the key-value pairs from the original data which have been accesses /// since the map was instantiated. #[derive(Debug, Default, Clone, Eq, PartialEq)] @@ -80,6 +87,13 @@ impl RecordingMap { } } + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + pub fn inner(&self) -> &BTreeMap { + &self.data + } + // FINALIZER // -------------------------------------------------------------------------------------------- @@ -148,6 +162,19 @@ impl KvMap for RecordingMap { }) } + /// Removes a key-value pair from the data set. + /// + /// If the key exists in the data set, the old value is returned. + fn remove(&mut self, key: &K) -> Option { + self.data.remove(key).map(|old_value| { + let new_update = self.updates.insert(key.clone()); + if new_update { + self.trace.borrow_mut().insert(key.clone(), old_value.clone()); + } + old_value + }) + } + // ITERATION // -------------------------------------------------------------------------------------------- @@ -180,6 +207,74 @@ impl IntoIterator for RecordingMap { } } +// KV MAP DIFF +// ================================================================================================ +/// [KvMapDiff] stores the difference between two key-value maps. +/// +/// The [KvMapDiff] is composed of two parts: +/// - `updates` - a map of key-value pairs that were updated in the second map compared to the +/// first map. This includes new key-value pairs. +/// - `removed` - a set of keys that were removed from the second map compared to the first map. +#[derive(Debug, Clone)] +pub struct KvMapDiff { + updated: BTreeMap, + removed: BTreeSet, +} + +impl KvMapDiff { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Creates a new [KvMapDiff] instance. 
+ pub fn new() -> Self { + KvMapDiff { + updated: BTreeMap::new(), + removed: BTreeSet::new(), + } + } +} + +impl Default for KvMapDiff { + fn default() -> Self { + Self::new() + } +} + +impl> Diff for T { + type DiffType = KvMapDiff; + + fn diff(&self, other: &T) -> Self::DiffType { + let mut diff = KvMapDiff::default(); + for (k, v) in self.iter() { + if let Some(other_value) = other.get(k) { + if v != other_value { + diff.updated.insert(k.clone(), other_value.clone()); + } + } else { + diff.removed.insert(k.clone()); + } + } + for (k, v) in other.iter() { + if self.get(k).is_none() { + diff.updated.insert(k.clone(), v.clone()); + } + } + diff + } +} + +impl> ApplyDiff for T { + type DiffType = KvMapDiff; + + fn apply(&mut self, diff: Self::DiffType) { + for (k, v) in diff.updated { + self.insert(k, v); + } + for k in diff.removed { + self.remove(&k); + } + } +} + // TESTS // ================================================================================================ @@ -321,4 +416,87 @@ mod tests { let map = RecordingMap::new(ITEMS.to_vec()); assert!(!map.is_empty()); } + + #[test] + fn test_remove() { + let mut map = RecordingMap::new(ITEMS.to_vec()); + + // remove an item that exists + let key = 0; + let value = map.remove(&key).unwrap(); + assert_eq!(value, ITEMS[0].1); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // add the item back and then remove it again + let key = 0; + let value = 0; + map.insert(key, value); + let value = map.remove(&key).unwrap(); + assert_eq!(value, 0); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // remove an item that does not exist + let key = 100; + let value = map.remove(&key); + assert_eq!(value, None); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 1); + + // insert a new item and then remove it + let key = 100; + let value = 100; + map.insert(key, value); + let value = map.remove(&key).unwrap(); + assert_eq!(value, 100); + assert_eq!(map.len(), ITEMS.len() - 1); + assert_eq!(map.trace_len(), 1); + assert_eq!(map.updates_len(), 2); + + // convert the map into a proof + let proof = map.into_proof(); + + // check that the proof contains the expected values + for (key, value) in ITEMS.iter() { + match key { + 0 => assert_eq!(proof.get(key), Some(value)), + _ => assert_eq!(proof.get(key), None), + } + } + } + + #[test] + fn test_kv_map_diff() { + let mut initial_state = ITEMS.into_iter().collect::>(); + let mut map = RecordingMap::new(initial_state.clone()); + + // remove an item that exists + let key = 0; + let _value = map.remove(&key).unwrap(); + + // add a new item + let key = 100; + let value = 100; + map.insert(key, value); + + // update an existing item + let key = 1; + let value = 100; + map.insert(key, value); + + // compute a diff + let diff = initial_state.diff(map.inner()); + assert!(diff.updated.len() == 2); + assert!(diff.updated.iter().all(|(k, v)| [(100, 100), (1, 100)].contains(&(*k, *v)))); + assert!(diff.removed.len() == 1); + assert!(diff.removed.first() == Some(&0)); + + // apply the diff to the initial state and assert the contents are the same as the map + initial_state.apply(diff); + assert!(initial_state.iter().eq(map.iter())); + } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7804420..d775b61 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -7,6 +7,7 @@ pub use alloc::format; #[cfg(feature = "std")] pub use 
std::format; +mod diff; mod kv_map; // RE-EXPORTS @@ -17,6 +18,7 @@ pub use winter_utils::{ }; pub mod collections { + pub use super::diff::*; pub use super::kv_map::*; pub use winter_utils::collections::*; } From da2d08714d595879804674d6778ff4f81787d863 Mon Sep 17 00:00:00 2001 From: frisitano Date: Mon, 10 Jul 2023 12:53:39 +0100 Subject: [PATCH 04/32] feat: introduce TryApplyDiff and refactor RecordingMap finalizer --- src/merkle/delta.rs | 153 +++++++++++++++++++++++++++++++++++ src/merkle/mod.rs | 5 +- src/merkle/simple_smt/mod.rs | 30 ++++++- src/merkle/store/mod.rs | 80 +++++++++++++----- src/merkle/store/tests.rs | 2 +- src/utils/diff.rs | 19 ++++- src/utils/kv_map.rs | 22 ++--- 7 files changed, 276 insertions(+), 35 deletions(-) create mode 100644 src/merkle/delta.rs diff --git a/src/merkle/delta.rs b/src/merkle/delta.rs new file mode 100644 index 0000000..71b822a --- /dev/null +++ b/src/merkle/delta.rs @@ -0,0 +1,153 @@ +use super::{ + BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word, +}; +use crate::utils::collections::Diff; + +#[cfg(test)] +use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt}; + +// MERKLE STORE DELTA +// ================================================================================================ + +/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the +/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the +/// differences between the initial and final Merkle tree states. +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); + +// MERKLE TREE DELTA +// ================================================================================================ + +/// [MerkleDelta] stores the differences between the initial and final Merkle tree states. +/// +/// The differences are represented as follows: +/// - depth: the depth of the merkle tree. +/// - cleared_slots: indexes of slots where values were set to [ZERO; 4]. +/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values. +#[cfg(not(test))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MerkleTreeDelta { + depth: u8, + cleared_slots: Vec, + updated_slots: Vec<(u64, Word)>, +} + +impl MerkleTreeDelta { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + pub fn new(depth: u8) -> Self { + Self { + depth, + cleared_slots: Vec::new(), + updated_slots: Vec::new(), + } + } + + // ACCESSORS + // -------------------------------------------------------------------------------------------- + /// Returns the depth of the Merkle tree the [MerkleDelta] is associated with. + pub fn depth(&self) -> u8 { + self.depth + } + + /// Returns the indexes of slots where values were set to [ZERO; 4]. + pub fn cleared_slots(&self) -> &[u64] { + &self.cleared_slots + } + + /// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values. + pub fn updated_slots(&self) -> &[(u64, Word)] { + &self.updated_slots + } + + // MODIFIERS + // -------------------------------------------------------------------------------------------- + /// Adds a slot index to the list of cleared slots. + pub fn add_cleared_slot(&mut self, index: u64) { + self.cleared_slots.push(index); + } + + /// Adds a slot index and a value to the list of updated slots. 
+ pub fn add_updated_slot(&mut self, index: u64, value: Word) { + self.updated_slots.push((index, value)); + } +} + +/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by +/// their roots and depth. +pub fn merkle_tree_delta>( + tree_root_1: RpoDigest, + tree_root_2: RpoDigest, + depth: u8, + merkle_store: &MerkleStore, +) -> Result { + if tree_root_1 == tree_root_2 { + return Ok(MerkleTreeDelta::new(depth)); + } + + let tree_1_leaves: BTreeMap = + merkle_store.non_empty_leaves(tree_root_1, depth).collect(); + let tree_2_leaves: BTreeMap = + merkle_store.non_empty_leaves(tree_root_2, depth).collect(); + let diff = tree_1_leaves.diff(&tree_2_leaves); + + // TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec. + Ok(MerkleTreeDelta { + depth, + cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(), + updated_slots: diff + .updated + .into_iter() + .map(|(index, leaf)| (index.value(), *leaf)) + .collect(), + }) +} + +// INTERNALS +// -------------------------------------------------------------------------------------------- +#[cfg(test)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MerkleTreeDelta { + pub depth: u8, + pub cleared_slots: Vec, + pub updated_slots: Vec<(u64, Word)>, +} + +// MERKLE DELTA +// ================================================================================================ +#[test] +fn test_compute_merkle_delta() { + let entries = vec![ + (10, [Felt::new(0), Felt::new(1), Felt::new(2), Felt::new(3)]), + (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]), + (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]), + (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]), + ]; + let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap(); + let mut store: MerkleStore = (&simple_smt).into(); + let root = simple_smt.root(); + + // add a new node + let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)]; + let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap(); + let root = store.set_node(root, new_index, new_value.into()).unwrap().root; + + // update an existing node + let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)]; + let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap(); + let root = store.set_node(root, update_idx, update_value.into()).unwrap().root; + + // remove a node + let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap(); + let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root; + + let merkle_delta = + merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap(); + let expected_merkle_delta = MerkleTreeDelta { + depth: simple_smt.depth(), + cleared_slots: vec![remove_idx.value()], + updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)], + }; + + assert_eq!(merkle_delta, expected_merkle_delta); +} diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 3e1c9d9..c49c004 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,6 +1,6 @@ use super::{ hash::rpo::{Rpo256, RpoDigest}, - utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec}, + utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec}, Felt, StarkField, Word, WORD_SIZE, ZERO, }; use core::fmt; @@ -11,6 +11,9 @@ use core::fmt; mod empty_roots; pub use empty_roots::EmptySubtreeRoots; +mod delta; +pub use delta::{merkle_tree_delta, 
MerkleStoreDelta, MerkleTreeDelta}; + mod index; pub use index::NodeIndex; diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index c8da302..542ab51 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -1,6 +1,6 @@ use super::{ - BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, - Rpo256, RpoDigest, Vec, Word, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta, + NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word, }; #[cfg(test)] @@ -275,3 +275,29 @@ impl BranchNode { Rpo256::merge(&[self.left, self.right]) } } + +// TRY APPLY DIFF +// ================================================================================================ +impl TryApplyDiff for SimpleSmt { + type Error = MerkleError; + type DiffType = MerkleTreeDelta; + + fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> { + if diff.depth() != self.depth() { + return Err(MerkleError::InvalidDepth { + expected: self.depth(), + provided: diff.depth(), + }); + } + + for slot in diff.cleared_slots() { + self.update_leaf(*slot, Self::EMPTY_VALUE)?; + } + + for (slot, value) in diff.updated_slots() { + self.update_leaf(*slot, *value)?; + } + + Ok(()) + } +} diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index f77d558..f250485 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -1,12 +1,9 @@ use super::{ - mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath, - MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt, - TieredSmt, ValuePath, Vec, -}; -use crate::utils::{ - collections::{ApplyDiff, Diff, KvMapDiff}, - ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, + empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, + MerkleError, MerklePath, MerklePathSet, MerkleStoreDelta, MerkleTree, NodeIndex, RecordingMap, + RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, }; +use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::borrow::Borrow; #[cfg(test)] @@ -280,6 +277,37 @@ impl> MerkleStore { }) } + /// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root` + /// and `max_depth`. 
+ pub fn non_empty_leaves( + &self, + root: RpoDigest, + max_depth: u8, + ) -> impl Iterator + '_ { + let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth); + let mut stack = Vec::new(); + stack.push((NodeIndex::new_unchecked(0, 0), root)); + + core::iter::from_fn(move || { + while let Some((index, node_hash)) = stack.pop() { + if index.depth() == max_depth { + return Some((index, node_hash)); + } + + if let Some(node) = self.nodes.get(&node_hash) { + if !empty_roots.contains(&node.left) { + stack.push((index.left_child(), node.left)); + } + if !empty_roots.contains(&node.right) { + stack.push((index.right_child(), node.right)); + } + } + } + + None + }) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -462,7 +490,6 @@ impl> FromIterator<(RpoDigest, StoreNode)> for Me // ITERATORS // ================================================================================================ - impl> Extend for MerkleStore { fn extend>(&mut self, iter: I) { self.nodes.extend(iter.into_iter().map(|info| { @@ -479,19 +506,34 @@ impl> Extend for MerkleStore { // DiffT & ApplyDiffT TRAIT IMPLEMENTATION // ================================================================================================ -impl> Diff for MerkleStore { - type DiffType = KvMapDiff; +impl> TryApplyDiff for MerkleStore { + type Error = MerkleError; + type DiffType = MerkleStoreDelta; - fn diff(&self, other: &Self) -> Self::DiffType { - self.nodes.diff(&other.nodes) - } -} + fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> { + for (root, delta) in diff.0 { + let mut root = root; + for cleared_slot in delta.cleared_slots() { + root = self + .set_node( + root, + NodeIndex::new(delta.depth(), *cleared_slot)?, + EMPTY_WORD.into(), + )? + .root; + } + for (updated_slot, updated_value) in delta.updated_slots() { + root = self + .set_node( + root, + NodeIndex::new(delta.depth(), *updated_slot)?, + (*updated_value).into(), + )? + .root; + } + } -impl> ApplyDiff for MerkleStore { - type DiffType = KvMapDiff; - - fn apply(&mut self, diff: Self::DiffType) { - self.nodes.apply(diff); + Ok(()) } } diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index c6f346f..5e5bce7 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -847,7 +847,7 @@ fn test_recorder() { // construct the proof let rec_map = recorder.into_inner(); - let proof = rec_map.into_proof(); + let (_, proof) = rec_map.finalize(); let merkle_store: MerkleStore = proof.into(); // make sure the proof contains all nodes from both trees diff --git a/src/utils/diff.rs b/src/utils/diff.rs index 48c80b6..97fc32f 100644 --- a/src/utils/diff.rs +++ b/src/utils/diff.rs @@ -1,16 +1,31 @@ /// A trait for computing the difference between two objects. pub trait Diff { + /// The type that describes the difference between two objects. type DiffType; - /// Returns a `Self::DiffType` object that represents the difference between this object and + /// Returns a [Self::DiffType] object that represents the difference between this object and /// other. fn diff(&self, other: &Self) -> Self::DiffType; } /// A trait for applying the difference between two objects. pub trait ApplyDiff { + /// The type that describes the difference between two objects. type DiffType; - /// Applies the provided changes described by [DiffType] to the object implementing this trait. + /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait. 
fn apply(&mut self, diff: Self::DiffType); } + +/// A trait for applying the difference between two objects with the possibility of failure. +pub trait TryApplyDiff { + /// The type that describes the difference between two objects. + type DiffType; + + /// An error type that can be returned if the changes cannot be applied. + type Error; + + /// Applies the provided changes described by [Self::DiffType] to the object implementing this trait. + /// Returns an error if the changes cannot be applied. + fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>; +} diff --git a/src/utils/kv_map.rs b/src/utils/kv_map.rs index 063a0a0..3c92b56 100644 --- a/src/utils/kv_map.rs +++ b/src/utils/kv_map.rs @@ -97,10 +97,12 @@ impl RecordingMap { // FINALIZER // -------------------------------------------------------------------------------------------- - /// Consumes the [RecordingMap] and returns a [BTreeMap] containing the key-value pairs from - /// the initial data set that were read during recording. - pub fn into_proof(self) -> BTreeMap { - self.trace.take() + /// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first + /// element of the tuple is a map that represents the state of the map at the time `.finalize()` + /// is called. The second element contains the key-value pairs from the initial data set that + /// were read during recording. + pub fn finalize(self) -> (BTreeMap, BTreeMap) { + (self.data, self.trace.take()) } // TEST HELPERS @@ -217,8 +219,8 @@ impl IntoIterator for RecordingMap { /// - `removed` - a set of keys that were removed from the second map compared to the first map. #[derive(Debug, Clone)] pub struct KvMapDiff { - updated: BTreeMap, - removed: BTreeSet, + pub updated: BTreeMap, + pub removed: BTreeSet, } impl KvMapDiff { @@ -296,7 +298,7 @@ mod tests { } // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, value) in ITEMS.iter() { @@ -319,7 +321,7 @@ mod tests { } // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, _) in ITEMS.iter() { @@ -383,7 +385,7 @@ mod tests { // Note: The length reported by the proof will be different to the length originally // reported by the map. - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // length of the proof should be equal to get_items + 1. 
The extra item is the original // value at key = 4u64 @@ -458,7 +460,7 @@ mod tests { assert_eq!(map.updates_len(), 2); // convert the map into a proof - let proof = map.into_proof(); + let (_, proof) = map.finalize(); // check that the proof contains the expected values for (key, value) in ITEMS.iter() { From 8c749e473a9810f9a2800813b0652a6863f31df3 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Wed, 26 Jul 2023 12:10:01 -0700 Subject: [PATCH 05/32] chore: update blake3 dependency to v1.4 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7bffb94..5caa2be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ default = ["blake3/default", "std", "winter_crypto/default", "winter_math/defaul std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] [dependencies] -blake3 = { version = "1.3", default-features = false } +blake3 = { version = "1.4", default-features = false } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } From 71b04d0734c5b31f8478e6dacce63936d75234f4 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Fri, 23 Jun 2023 23:12:52 +0300 Subject: [PATCH 06/32] refactor: replace MerklePathSet with PartialMerkleTree --- CHANGELOG.md | 2 + README.md | 1 - src/merkle/mod.rs | 3 - src/merkle/partial_mt/mod.rs | 2 + src/merkle/path_set.rs | 408 ----------------------------------- src/merkle/store/mod.rs | 25 +-- src/merkle/store/tests.rs | 77 ++++--- src/utils/kv_map.rs | 4 +- 8 files changed, 53 insertions(+), 469 deletions(-) delete mode 100644 src/merkle/path_set.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 355d33f..5f833a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## 0.7.0 (TBD) +* Replaced `MerklePathSet` with `PartialMerkleTree` (#165). + ## 0.6.0 (2023-06-25) * [BREAKING] Added support for recording capabilities for `MerkleStore` (#162). diff --git a/README.md b/README.md index b0bbdfe..6274cea 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ For performance benchmarks of these hash functions and their comparison to other * `Mmr`: a Merkle mountain range structure designed to function as an append-only log. * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64. -* `MerklePathSet`: a collection of Merkle authentication paths all resolving to the same root. The length of the paths can be at most 64. * `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data. * `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64. * `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values. 
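A minimal sketch of the replacement in use, assuming the `with_paths`, `root`, and `update_leaf` signatures introduced earlier in this series and the crate's current module paths; a caller that previously built a `MerklePathSet` from an authentication path can hold a `PartialMerkleTree` instead:

use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{MerkleError, MerklePath, PartialMerkleTree},
    Word,
};

/// Rebuilds a partial view of a tree from one authentication path, then updates the
/// opened leaf, returning the previous value stored at that position.
fn reopen_and_update(
    index: u64,
    value: RpoDigest,
    path: MerklePath,
    new_value: Word,
) -> Result<RpoDigest, MerkleError> {
    // All provided paths must resolve to the same root; a single path always does.
    let mut pmt = PartialMerkleTree::with_paths([(index, value, path)])?;

    // The root is recomputed from the supplied path.
    let _root = pmt.root();

    // Leaves are addressed by their index on the deepest layer of the partial tree.
    pmt.update_leaf(index, new_value)
}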
diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index c49c004..9d20629 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -23,9 +23,6 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree}; mod path; pub use path::{MerklePath, RootPath, ValuePath}; -mod path_set; -pub use path_set::MerklePathSet; - mod simple_smt; pub use simple_smt::SimpleSmt; diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index ef87516..3f33535 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -340,6 +340,8 @@ impl PartialMerkleTree { } /// Updates value of the leaf at the specified index returning the old leaf value. + /// By default the specified index is assumed to belong to the deepest layer. If the considered + /// node does not belong to the tree, the first node on the way to the root will be changed. /// /// By default the specified index is assumed to belong to the deepest layer. If the considered /// node does not belong to the tree, the first node on the way to the root will be changed. diff --git a/src/merkle/path_set.rs b/src/merkle/path_set.rs deleted file mode 100644 index 169e073..0000000 --- a/src/merkle/path_set.rs +++ /dev/null @@ -1,408 +0,0 @@ -use super::{BTreeMap, MerkleError, MerklePath, NodeIndex, Rpo256, ValuePath, Vec}; -use crate::{hash::rpo::RpoDigest, Word}; - -// MERKLE PATH SET -// ================================================================================================ - -/// A set of Merkle paths. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct MerklePathSet { - root: RpoDigest, - total_depth: u8, - paths: BTreeMap, -} - -impl MerklePathSet { - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - - /// Returns an empty MerklePathSet. - pub fn new(depth: u8) -> Self { - let root = RpoDigest::default(); - let paths = BTreeMap::new(); - - Self { - root, - total_depth: depth, - paths, - } - } - - /// Appends the provided paths iterator into the set. - /// - /// Analogous to `[Self::add_path]`. - pub fn with_paths(self, paths: I) -> Result - where - I: IntoIterator, - { - paths.into_iter().try_fold(self, |mut set, (index, value, path)| { - set.add_path(index, value.into(), path)?; - Ok(set) - }) - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns the root to which all paths in this set resolve. - pub const fn root(&self) -> RpoDigest { - self.root - } - - /// Returns the depth of the Merkle tree implied by the paths stored in this set. - /// - /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc. - pub const fn depth(&self) -> u8 { - self.total_depth - } - - /// Returns a node at the specified index. - /// - /// # Errors - /// Returns an error if: - /// * The specified index is not valid for the depth of structure. - /// * Requested node does not exist in the set. - pub fn get_node(&self, index: NodeIndex) -> Result { - if index.depth() != self.total_depth { - return Err(MerkleError::InvalidDepth { - expected: self.total_depth, - provided: index.depth(), - }); - } - - let parity = index.value() & 1; - let path_key = index.value() - parity; - self.paths - .get(&path_key) - .ok_or(MerkleError::NodeNotInSet(index)) - .map(|path| path[parity as usize]) - } - - /// Returns a leaf at the specified index. - /// - /// # Errors - /// * The specified index is not valid for the depth of the structure. 
- /// * Leaf with the requested path does not exist in the set. - pub fn get_leaf(&self, index: u64) -> Result { - let index = NodeIndex::new(self.depth(), index)?; - Ok(self.get_node(index)?.into()) - } - - /// Returns a Merkle path to the node at the specified index. The node itself is - /// not included in the path. - /// - /// # Errors - /// Returns an error if: - /// * The specified index is not valid for the depth of structure. - /// * Node of the requested path does not exist in the set. - pub fn get_path(&self, index: NodeIndex) -> Result { - if index.depth() != self.total_depth { - return Err(MerkleError::InvalidDepth { - expected: self.total_depth, - provided: index.depth(), - }); - } - - let parity = index.value() & 1; - let path_key = index.value() - parity; - let mut path = - self.paths.get(&path_key).cloned().ok_or(MerkleError::NodeNotInSet(index))?; - path.remove(parity as usize); - Ok(path) - } - - /// Returns all paths in this path set together with their indexes. - pub fn to_paths(&self) -> Vec<(u64, ValuePath)> { - let mut result = Vec::with_capacity(self.paths.len() * 2); - - for (&index, path) in self.paths.iter() { - // push path for the even index into the result - let path1 = ValuePath { - value: path[0], - path: MerklePath::new(path[1..].to_vec()), - }; - result.push((index, path1)); - - // push path for the odd index into the result - let mut path2 = path.clone(); - let leaf2 = path2.remove(1); - let path2 = ValuePath { - value: leaf2, - path: path2, - }; - result.push((index + 1, path2)); - } - - result - } - - // STATE MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Adds the specified Merkle path to this [MerklePathSet]. The `index` and `value` parameters - /// specify the leaf node at which the path starts. - /// - /// # Errors - /// Returns an error if: - /// - The specified index is is not valid in the context of this Merkle path set (i.e., the - /// index implies a greater depth than is specified for this set). - /// - The specified path is not consistent with other paths in the set (i.e., resolves to a - /// different root). - pub fn add_path( - &mut self, - index_value: u64, - value: Word, - mut path: MerklePath, - ) -> Result<(), MerkleError> { - let mut index = NodeIndex::new(path.len() as u8, index_value)?; - if index.depth() != self.total_depth { - return Err(MerkleError::InvalidDepth { - expected: self.total_depth, - provided: index.depth(), - }); - } - - // update the current path - let parity = index_value & 1; - path.insert(parity as usize, value.into()); - - // traverse to the root, updating the nodes - let root = Rpo256::merge(&[path[0], path[1]]); - let root = path.iter().skip(2).copied().fold(root, |root, hash| { - index.move_up(); - Rpo256::merge(&index.build_node(root, hash)) - }); - - // if the path set is empty (the root is all ZEROs), set the root to the root of the added - // path; otherwise, the root of the added path must be identical to the current root - if self.root == RpoDigest::default() { - self.root = root; - } else if self.root != root { - return Err(MerkleError::ConflictingRoots([self.root, root].to_vec())); - } - - // finish updating the path - let path_key = index_value - parity; - self.paths.insert(path_key, path); - Ok(()) - } - - /// Replaces the leaf at the specified index with the provided value. - /// - /// # Errors - /// Returns an error if: - /// * Requested node does not exist in the set. 
- pub fn update_leaf(&mut self, base_index_value: u64, value: Word) -> Result<(), MerkleError> { - let mut index = NodeIndex::new(self.depth(), base_index_value)?; - let parity = index.value() & 1; - let path_key = index.value() - parity; - let path = match self.paths.get_mut(&path_key) { - Some(path) => path, - None => return Err(MerkleError::NodeNotInSet(index)), - }; - - // Fill old_hashes vector ----------------------------------------------------------------- - let mut current_index = index; - let mut old_hashes = Vec::with_capacity(path.len().saturating_sub(2)); - let mut root = Rpo256::merge(&[path[0], path[1]]); - for hash in path.iter().skip(2).copied() { - old_hashes.push(root); - current_index.move_up(); - let input = current_index.build_node(hash, root); - root = Rpo256::merge(&input); - } - - // Fill new_hashes vector ----------------------------------------------------------------- - path[index.is_value_odd() as usize] = value.into(); - - let mut new_hashes = Vec::with_capacity(path.len().saturating_sub(2)); - let mut new_root = Rpo256::merge(&[path[0], path[1]]); - for path_hash in path.iter().skip(2).copied() { - new_hashes.push(new_root); - index.move_up(); - let input = current_index.build_node(path_hash, new_root); - new_root = Rpo256::merge(&input); - } - - self.root = new_root; - - // update paths --------------------------------------------------------------------------- - for path in self.paths.values_mut() { - for i in (0..old_hashes.len()).rev() { - if path[i + 2] == old_hashes[i] { - path[i + 2] = new_hashes[i]; - break; - } - } - } - - Ok(()) - } -} - -// TESTS -// ================================================================================================ - -#[cfg(test)] -mod tests { - use super::*; - use crate::merkle::{int_to_leaf, int_to_node}; - - #[test] - fn get_root() { - let leaf0 = int_to_node(0); - let leaf1 = int_to_node(1); - let leaf2 = int_to_node(2); - let leaf3 = int_to_node(3); - - let parent0 = calculate_parent_hash(leaf0, 0, leaf1); - let parent1 = calculate_parent_hash(leaf2, 2, leaf3); - - let root_exp = calculate_parent_hash(parent0, 0, parent1); - - let set = super::MerklePathSet::new(2) - .with_paths([(0, leaf0, vec![leaf1, parent1].into())]) - .unwrap(); - - assert_eq!(set.root(), root_exp); - } - - #[test] - fn add_and_get_path() { - let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_node(6); - let index = 6_u64; - let depth = 3_u8; - let set = super::MerklePathSet::new(depth) - .with_paths([(index, hash_6, path_6.clone().into())]) - .unwrap(); - let stored_path_6 = set.get_path(NodeIndex::make(depth, index)).unwrap(); - - assert_eq!(path_6, *stored_path_6); - } - - #[test] - fn get_node() { - let path_6 = vec![int_to_node(7), int_to_node(45), int_to_node(123)]; - let hash_6 = int_to_node(6); - let index = 6_u64; - let depth = 3_u8; - let set = MerklePathSet::new(depth).with_paths([(index, hash_6, path_6.into())]).unwrap(); - - assert_eq!(int_to_node(6u64), set.get_node(NodeIndex::make(depth, index)).unwrap()); - } - - #[test] - fn update_leaf() { - let hash_4 = int_to_node(4); - let hash_5 = int_to_node(5); - let hash_6 = int_to_node(6); - let hash_7 = int_to_node(7); - let hash_45 = calculate_parent_hash(hash_4, 12u64, hash_5); - let hash_67 = calculate_parent_hash(hash_6, 14u64, hash_7); - - let hash_0123 = int_to_node(123); - - let path_6 = vec![hash_7, hash_45, hash_0123]; - let path_5 = vec![hash_4, hash_67, hash_0123]; - let path_4 = vec![hash_5, hash_67, hash_0123]; - - let index_6 = 
6_u64; - let index_5 = 5_u64; - let index_4 = 4_u64; - let depth = 3_u8; - let mut set = MerklePathSet::new(depth) - .with_paths([ - (index_6, hash_6, path_6.into()), - (index_5, hash_5, path_5.into()), - (index_4, hash_4, path_4.into()), - ]) - .unwrap(); - - let new_hash_6 = int_to_leaf(100); - let new_hash_5 = int_to_leaf(55); - - set.update_leaf(index_6, new_hash_6).unwrap(); - let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap(); - let new_hash_67 = calculate_parent_hash(new_hash_6.into(), 14_u64, hash_7); - assert_eq!(new_hash_67, new_path_4[1]); - - set.update_leaf(index_5, new_hash_5).unwrap(); - let new_path_4 = set.get_path(NodeIndex::make(depth, index_4)).unwrap(); - let new_path_6 = set.get_path(NodeIndex::make(depth, index_6)).unwrap(); - let new_hash_45 = calculate_parent_hash(new_hash_5.into(), 13_u64, hash_4); - assert_eq!(new_hash_45, new_path_6[1]); - assert_eq!(RpoDigest::from(new_hash_5), new_path_4[0]); - } - - #[test] - fn depth_3_is_correct() { - let a = int_to_node(1); - let b = int_to_node(2); - let c = int_to_node(3); - let d = int_to_node(4); - let e = int_to_node(5); - let f = int_to_node(6); - let g = int_to_node(7); - let h = int_to_node(8); - - let i = Rpo256::merge(&[a, b]); - let j = Rpo256::merge(&[c, d]); - let k = Rpo256::merge(&[e, f]); - let l = Rpo256::merge(&[g, h]); - - let m = Rpo256::merge(&[i, j]); - let n = Rpo256::merge(&[k, l]); - - let root = Rpo256::merge(&[m, n]); - - let mut set = MerklePathSet::new(3); - - let value = b; - let index = 1; - let path = MerklePath::new([a, j, n].to_vec()); - set.add_path(index, value.into(), path).unwrap(); - assert_eq!(*value, set.get_leaf(index).unwrap()); - assert_eq!(root, set.root()); - - let value = e; - let index = 4; - let path = MerklePath::new([f, l, m].to_vec()); - set.add_path(index, value.into(), path).unwrap(); - assert_eq!(*value, set.get_leaf(index).unwrap()); - assert_eq!(root, set.root()); - - let value = a; - let index = 0; - let path = MerklePath::new([b, j, n].to_vec()); - set.add_path(index, value.into(), path).unwrap(); - assert_eq!(*value, set.get_leaf(index).unwrap()); - assert_eq!(root, set.root()); - - let value = h; - let index = 7; - let path = MerklePath::new([g, k, m].to_vec()); - set.add_path(index, value.into(), path).unwrap(); - assert_eq!(*value, set.get_leaf(index).unwrap()); - assert_eq!(root, set.root()); - } - - // HELPER FUNCTIONS - // -------------------------------------------------------------------------------------------- - - const fn is_even(pos: u64) -> bool { - pos & 1 == 0 - } - - /// Calculates the hash of the parent node by two sibling ones - /// - node — current node - /// - node_pos — position of the current node - /// - sibling — neighboring vertex in the tree - fn calculate_parent_hash(node: RpoDigest, node_pos: u64, sibling: RpoDigest) -> RpoDigest { - if is_even(node_pos) { - Rpo256::merge(&[node, sibling]) - } else { - Rpo256::merge(&[sibling, node]) - } - } -} diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index f250485..8d8b80a 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -1,7 +1,7 @@ use super::{ empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, - MerkleError, MerklePath, MerklePathSet, MerkleStoreDelta, MerkleTree, NodeIndex, RecordingMap, - RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, + MerkleError, MerklePath, MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, + RecordingMap, RootPath, Rpo256, RpoDigest, 
SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, }; use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::borrow::Borrow; @@ -351,20 +351,6 @@ impl> MerkleStore { Ok(()) } - /// Appends the provided [MerklePathSet] into the store. - /// - /// For further reference, check [MerkleStore::add_merkle_path]. - pub fn add_merkle_path_set( - &mut self, - path_set: &MerklePathSet, - ) -> Result { - let root = path_set.root(); - for (index, path) in path_set.to_paths() { - self.add_merkle_path(index, path.value, path.path)?; - } - Ok(root) - } - /// Sets a node to `value`. /// /// # Errors @@ -467,6 +453,13 @@ impl> From<&TieredSmt> for MerkleStore { } } +impl> From<&PartialMerkleTree> for MerkleStore { + fn from(value: &PartialMerkleTree) -> Self { + let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect(); + Self { nodes } + } +} + impl> From for MerkleStore { fn from(values: T) -> Self { let nodes = values.into_iter().chain(empty_hashes()).collect(); diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index 5e5bce7..dbc071e 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -1,10 +1,10 @@ use super::{ DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, - RecordingMerkleStore, RpoDigest, + PartialMerkleTree, RecordingMerkleStore, RpoDigest, }; use crate::{ hash::rpo::Rpo256, - merkle::{digests_to_words, int_to_leaf, int_to_node, MerklePathSet, MerkleTree, SimpleSmt}, + merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt}, Felt, Word, ONE, WORD_SIZE, ZERO, }; @@ -378,97 +378,96 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { let mut store = MerkleStore::default(); store.add_merkle_paths(paths.clone()).expect("the valid paths must work"); - let depth = 2; - let set = MerklePathSet::new(depth).with_paths(paths).unwrap(); + let pmt = PartialMerkleTree::with_paths(paths).unwrap(); // STORE LEAVES ARE CORRECT ============================================================== // checks the leaves in the store corresponds to the expected values assert_eq!( - store.get_node(set.root(), NodeIndex::make(set.depth(), 0)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)), Ok(VALUES4[0]), - "node 0 must be in the set" + "node 0 must be in the pmt" ); assert_eq!( - store.get_node(set.root(), NodeIndex::make(set.depth(), 1)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)), Ok(VALUES4[1]), - "node 1 must be in the set" + "node 1 must be in the pmt" ); assert_eq!( - store.get_node(set.root(), NodeIndex::make(set.depth(), 2)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)), Ok(VALUES4[2]), - "node 2 must be in the set" + "node 2 must be in the pmt" ); assert_eq!( - store.get_node(set.root(), NodeIndex::make(set.depth(), 3)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)), Ok(VALUES4[3]), - "node 3 must be in the set" + "node 3 must be in the pmt" ); - // STORE LEAVES MATCH SET ================================================================ - // sanity check the values returned by the store and the set + // STORE LEAVES MATCH PMT ================================================================ + // sanity check the values returned by the store and the pmt assert_eq!( - set.get_node(NodeIndex::make(set.depth(), 0)), - store.get_node(set.root(), NodeIndex::make(set.depth(), 0)), - "node 0 must be the same for both SparseMerkleTree and MerkleStore" + 
pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)), + "node 0 must be the same for both PartialMerkleTree and MerkleStore" ); assert_eq!( - set.get_node(NodeIndex::make(set.depth(), 1)), - store.get_node(set.root(), NodeIndex::make(set.depth(), 1)), - "node 1 must be the same for both SparseMerkleTree and MerkleStore" + pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)), + "node 1 must be the same for both PartialMerkleTree and MerkleStore" ); assert_eq!( - set.get_node(NodeIndex::make(set.depth(), 2)), - store.get_node(set.root(), NodeIndex::make(set.depth(), 2)), - "node 2 must be the same for both SparseMerkleTree and MerkleStore" + pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)), + "node 2 must be the same for both PartialMerkleTree and MerkleStore" ); assert_eq!( - set.get_node(NodeIndex::make(set.depth(), 3)), - store.get_node(set.root(), NodeIndex::make(set.depth(), 3)), - "node 3 must be the same for both SparseMerkleTree and MerkleStore" + pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)), + store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)), + "node 3 must be the same for both PartialMerkleTree and MerkleStore" ); // STORE MERKLE PATH MATCHS ============================================================== - // assert the merkle path returned by the store is the same as the one in the set - let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 0)).unwrap(); + // assert the merkle path returned by the store is the same as the one in the pmt + let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(); assert_eq!( VALUES4[0], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( - set.get_path(NodeIndex::make(set.depth(), 0)), + pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)), Ok(result.path), "merkle path for index 0 must be the same for the MerkleTree and MerkleStore" ); - let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 1)).unwrap(); + let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(); assert_eq!( VALUES4[1], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( - set.get_path(NodeIndex::make(set.depth(), 1)), + pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)), Ok(result.path), "merkle path for index 1 must be the same for the MerkleTree and MerkleStore" ); - let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 2)).unwrap(); + let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(); assert_eq!( VALUES4[2], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( - set.get_path(NodeIndex::make(set.depth(), 2)), + pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)), Ok(result.path), "merkle path for index 0 must be the same for the MerkleTree and MerkleStore" ); - let result = store.get_path(set.root(), NodeIndex::make(set.depth(), 3)).unwrap(); + let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(); assert_eq!( VALUES4[3], result.value, "Value for merkle path at index 0 must match leaf value" ); assert_eq!( - set.get_path(NodeIndex::make(set.depth(), 3)), + pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)), Ok(result.path), "merkle path for index 0 must be the same for the MerkleTree and MerkleStore" ); @@ 
-585,16 +584,16 @@ fn test_constructors() -> Result<(), MerkleError> { store2.add_merkle_path(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?; store2.add_merkle_path(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?; store2.add_merkle_path(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?; - let set = MerklePathSet::new(d).with_paths(paths).unwrap(); + let pmt = PartialMerkleTree::with_paths(paths).unwrap(); for key in [0, 1, 2, 3] { let index = NodeIndex::make(d, key); - let value_path1 = store1.get_path(set.root(), index)?; - let value_path2 = store2.get_path(set.root(), index)?; + let value_path1 = store1.get_path(pmt.root(), index)?; + let value_path2 = store2.get_path(pmt.root(), index)?; assert_eq!(value_path1, value_path2); let index = NodeIndex::make(d, key); - assert_eq!(set.get_path(index)?, value_path1.path); + assert_eq!(pmt.get_path(index)?, value_path1.path); } Ok(()) diff --git a/src/utils/kv_map.rs b/src/utils/kv_map.rs index 3c92b56..136eb55 100644 --- a/src/utils/kv_map.rs +++ b/src/utils/kv_map.rs @@ -326,8 +326,8 @@ mod tests { // check that the proof contains the expected values for (key, _) in ITEMS.iter() { match get_items.contains(key) { - true => assert_eq!(proof.contains_key(key), true), - false => assert_eq!(proof.contains_key(key), false), + true => assert!(proof.contains_key(key)), + false => assert!(!proof.contains_key(key)), } } } From 1578a9ee1fc4903df987b31944f31a71ecef4e73 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Thu, 27 Jul 2023 21:15:45 -0700 Subject: [PATCH 07/32] refactor: simplify TSTM leaf node hashing --- src/merkle/tiered_smt/mod.rs | 39 ++++++++++------------------------ src/merkle/tiered_smt/tests.rs | 10 +++------ 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 3269517..7ec8afe 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -1,6 +1,6 @@ use super::{ BTreeMap, BTreeSet, EmptySubtreeRoots, Felt, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, - Rpo256, RpoDigest, StarkField, Vec, Word, ZERO, + Rpo256, RpoDigest, StarkField, Vec, Word, }; use core::cmp; @@ -27,8 +27,8 @@ mod tests; /// /// To differentiate between internal and leaf nodes, node values are computed as follows: /// - Internal nodes: hash(left_child, right_child). -/// - Leaf node at depths 16, 32, or 64: hash(rem_key, value, domain=depth). -/// - Leaf node at depth 64: hash([rem_key_0, value_0, ..., rem_key_n, value_n, domain=64]). +/// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth). +/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n, domain=64]). /// /// Where rem_key is computed by replacing d most significant bits of the key with zeros where d /// is depth (i.e., for a leaf at depth 16, we replace 16 most significant bits of the key with 0). 
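As a quick illustration of the hashing rule described above (a standalone sketch, not part of the patch; the import paths miden_crypto::hash::rpo::{Rpo256, RpoDigest} and miden_crypto::{Felt, Word} are assumed from the tests elsewhere in this series): an upper-tier leaf is hashed over the full, untruncated key, with the tier depth supplied as the hash domain.

use miden_crypto::hash::rpo::{Rpo256, RpoDigest};
use miden_crypto::{Felt, Word};

fn main() {
    // a key-value pair destined for an upper-tier leaf (depth 16, 32, or 48)
    let key = RpoDigest::from([Felt::new(1), Felt::new(1), Felt::new(1), Felt::new(0xABCD_u64 << 48)]);
    let value: Word = [Felt::new(1); 4];

    // the full key goes into the hash; the tier depth only enters via the domain separator
    let leaf_16 = Rpo256::merge_in_domain(&[key, value.into()], 16u8.into());
    let leaf_32 = Rpo256::merge_in_domain(&[key, value.into()], 32u8.into());

    // the same pair therefore produces different node values at different tiers
    assert_ne!(leaf_16, leaf_32);
}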
@@ -36,7 +36,7 @@ mod tests; pub struct TieredSmt { root: RpoDigest, nodes: BTreeMap, - upper_leaves: BTreeMap, // node_index |-> key map + upper_leaves: BTreeMap, // node_index |-> key bottom_leaves: BTreeMap, // leaves of depth 64 values: BTreeMap, } @@ -180,7 +180,7 @@ impl TieredSmt { let other_index = key_to_index(&other_key, depth); let other_value = *self.values.get(&other_key).expect("no value for other key"); self.upper_leaves.remove(&index).expect("other node key not in map"); - self.insert_node(other_index, other_key, other_value); + self.insert_leaf_node(other_index, other_key, other_value); // the new leaf also needs to move down to the same tier index = key_to_index(&key, depth); @@ -188,7 +188,7 @@ impl TieredSmt { } // insert the node and return the old value - self.insert_node(index, key, value); + self.insert_leaf_node(index, key, value); old_value } @@ -307,8 +307,9 @@ impl TieredSmt { /// Inserts the provided key-value pair at the specified index and updates the root of this /// Merkle tree by recomputing the path to the root. - fn insert_node(&mut self, mut index: NodeIndex, key: RpoDigest, value: Word) { + fn insert_leaf_node(&mut self, mut index: NodeIndex, key: RpoDigest, value: Word) { let depth = index.depth(); + debug_assert!(Self::TIER_DEPTHS.contains(&depth)); // insert the key into index-key map and compute the new value of the node let mut node = if index.depth() == Self::MAX_DEPTH { @@ -323,9 +324,8 @@ impl TieredSmt { // for the upper tiers, we just update the index-key map and compute the value of the // node self.upper_leaves.insert(index, key); - // the node value is computed as: hash(remaining_key || value, domain = depth) - let remaining_path = get_remaining_path(key, depth.into()); - Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into()) + // the node value is computed as: hash(key || value, domain = depth) + Rpo256::merge_in_domain(&[key, value.into()], depth.into()) }; // insert the node and update the path from the node to the root @@ -357,21 +357,6 @@ impl Default for TieredSmt { // HELPER FUNCTIONS // ================================================================================================ -/// Returns the remaining path for the specified key at the specified depth. -/// -/// Remaining path is computed by setting n most significant bits of the key to zeros, where n is -/// the specified depth. -fn get_remaining_path(key: RpoDigest, depth: u32) -> RpoDigest { - let mut key = Word::from(key); - key[3] = if depth == 64 { - ZERO - } else { - // remove `depth` bits from the most significant key element - ((key[3].as_int() << depth) >> depth).into() - }; - key.into() -} - /// Returns index for the specified key inserted at the specified depth. /// /// The value for the key is computed by taking n most significant bits from the most significant @@ -443,14 +428,12 @@ impl BottomLeaf { pub fn new(key: RpoDigest, value: Word) -> Self { let prefix = Word::from(key)[3].as_int(); let mut values = BTreeMap::new(); - let key = get_remaining_path(key, TieredSmt::MAX_DEPTH as u32); values.insert(key.into(), value); Self { prefix, values } } /// Adds a new key-value pair to this leaf. 
pub fn add_value(&mut self, key: RpoDigest, value: Word) { - let key = get_remaining_path(key, TieredSmt::MAX_DEPTH as u32); self.values.insert(key.into(), value); } @@ -476,7 +459,7 @@ impl BottomLeaf { Felt::new(key[0]), Felt::new(key[1]), Felt::new(key[2]), - Felt::new(self.prefix), + Felt::new(key[3]), ]); (key, *val) }) diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index fbf26ea..d8e6723 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -1,7 +1,6 @@ use super::{ super::{super::ONE, Felt, MerkleStore, WORD_SIZE, ZERO}, - get_remaining_path, EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, - Vec, Word, + EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word, }; #[test] @@ -411,8 +410,7 @@ fn get_init_root() -> RpoDigest { } fn build_leaf_node(key: RpoDigest, value: Word, depth: u8) -> RpoDigest { - let remaining_path = get_remaining_path(key, depth as u32); - Rpo256::merge_in_domain(&[remaining_path, value.into()], depth.into()) + Rpo256::merge_in_domain(&[key, value.into()], depth.into()) } fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest { @@ -420,9 +418,7 @@ fn build_bottom_leaf_node(keys: &[RpoDigest], values: &[Word]) -> RpoDigest { let mut elements = Vec::with_capacity(keys.len()); for (key, val) in keys.iter().zip(values.iter()) { - let mut key = Word::from(key); - key[3] = ZERO; - elements.extend_from_slice(&key); + elements.extend_from_slice(key.as_elements()); elements.extend_from_slice(val.as_slice()); } From 1bb75e85dd0dbac075697076e05e90eb2243829f Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Fri, 28 Jul 2023 01:44:06 -0700 Subject: [PATCH 08/32] feat: implement value removal in TSMT --- src/merkle/tiered_smt/mod.rs | 399 ++++++++++------------ src/merkle/tiered_smt/nodes.rs | 356 ++++++++++++++++++++ src/merkle/tiered_smt/tests.rs | 209 ++++++++++++ src/merkle/tiered_smt/values.rs | 580 ++++++++++++++++++++++++++++++++ 4 files changed, 1320 insertions(+), 224 deletions(-) create mode 100644 src/merkle/tiered_smt/nodes.rs create mode 100644 src/merkle/tiered_smt/values.rs diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 7ec8afe..52a3f89 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -1,9 +1,15 @@ use super::{ - BTreeMap, BTreeSet, EmptySubtreeRoots, Felt, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, StarkField, Vec, Word, }; use core::cmp; +mod nodes; +use nodes::NodeStore; + +mod values; +use values::ValueStore; + #[cfg(test)] mod tests; @@ -18,27 +24,22 @@ mod tests; /// of depth 64 (i.e., leaves at depth 64 are set to [ZERO; 4]). As non-empty values are inserted /// into the tree they are added to the first available tier. /// -/// For example, when the first key-value is inserted, it will be stored in a node at depth 16 -/// such that the first 16 bits of the key determine the position of the node at depth 16. If -/// another value with a key sharing the same 16-bit prefix is inserted, both values move into -/// the next tier (depth 32). This process is repeated until values end up at tier 64. If multiple -/// values have keys with a common 64-bit prefix, such key-value pairs are stored in a sorted list -/// at the last tier (depth = 64). 
+/// For example, when the first key-value pair is inserted, it will be stored in a node at depth +/// 16 such that the 16 most significant bits of the key determine the position of the node at +/// depth 16. If another value with a key sharing the same 16-bit prefix is inserted, both values +/// move into the next tier (depth 32). This process is repeated until values end up at the bottom +/// tier (depth 64). If multiple values have keys with a common 64-bit prefix, such key-value pairs +/// are stored in a sorted list at the bottom tier. /// /// To differentiate between internal and leaf nodes, node values are computed as follows: /// - Internal nodes: hash(left_child, right_child). /// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth). /// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n, domain=64]). -/// -/// Where rem_key is computed by replacing d most significant bits of the key with zeros where d -/// is depth (i.e., for a leaf at depth 16, we replace 16 most significant bits of the key with 0). #[derive(Debug, Clone, PartialEq, Eq)] pub struct TieredSmt { root: RpoDigest, - nodes: BTreeMap, - upper_leaves: BTreeMap, // node_index |-> key - bottom_leaves: BTreeMap, // leaves of depth 64 - values: BTreeMap, + nodes: NodeStore, + values: ValueStore, } impl TieredSmt { @@ -106,8 +107,7 @@ impl TieredSmt { /// when a leaf node with the same index prefix exists at a tier higher than the requested /// node. pub fn get_node(&self, index: NodeIndex) -> Result { - self.validate_node_access(index)?; - Ok(self.get_node_unchecked(&index)) + self.nodes.get_node(index) } /// Returns a Merkle path from the node at the specified index to the root. @@ -120,17 +120,8 @@ impl TieredSmt { /// - The node with the specified index does not exists in the Merkle tree. This is possible /// when a leaf node with the same index prefix exists at a tier higher than the node to /// which the path is requested. - pub fn get_path(&self, mut index: NodeIndex) -> Result { - self.validate_node_access(index)?; - - let mut path = Vec::with_capacity(index.depth() as usize); - for _ in 0..index.depth() { - let node = self.get_node_unchecked(&index.sibling()); - path.push(node); - index.move_up(); - } - - Ok(path.into()) + pub fn get_path(&self, index: NodeIndex) -> Result { + self.nodes.get_path(index) } /// Returns the value associated with the specified key. @@ -151,44 +142,72 @@ impl TieredSmt { /// /// If the value for the specified key was not previously set, [ZERO; 4] is returned. pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word { - // insert the value into the key-value map, and if nothing has changed, return - let old_value = self.values.insert(key, value).unwrap_or(Self::EMPTY_VALUE); - if old_value == value { - return old_value; + // if an empty value is being inserted, remove the leaf node to make it look as if the + // value was never inserted + if value == Self::EMPTY_VALUE { + return self.remove_leaf_node(key); } + // insert the value into the value store, and if nothing has changed, return + let (old_value, is_update) = match self.values.insert(key, value) { + Some(old_value) => { + if old_value == value { + return old_value; + } + (old_value, true) + } + None => (Self::EMPTY_VALUE, false), + }; + // determine the index for the value node; this index could have 3 different meanings: // - it points to a root of an empty subtree (excluding depth = 64); in this case, we can // replace the node with the value node immediately. 
// - it points to a node at the bottom tier (i.e., depth = 64); in this case, we need to - // process bottom-tier insertion which will be handled by insert_node(). - // - it points to a leaf node; this node could be a node with the same key or a different - // key with a common prefix; in the latter case, we'll need to move the leaf to a lower - // tier; for this scenario the `leaf_key` will contain the key of the leaf node - let (mut index, leaf_key) = self.get_insert_location(&key); + // process bottom-tier insertion which will be handled by insert_leaf_node(). + // - it points to an existing leaf node; this node could be a node with the same key or a + // different key with a common prefix; in the latter case, we'll need to move the leaf + // to a lower tier + let (index, leaf_exists) = self.nodes.get_insert_location(&key); + debug_assert!(!is_update || leaf_exists); - // if the returned index points to a leaf, and this leaf is for a different key, we need - // to move the leaf to a lower tier - if let Some(other_key) = leaf_key { - if other_key != key { - // determine how far down the tree should we move the existing leaf - let common_prefix_len = get_common_prefix_tier(&key, &other_key); - let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); + // if the returned index points to a leaf, and this leaf is for a different key (i.e., we + // are not updating a value for an existing key), we need to replace this leaf with a tree + // containing leaves for both the old and the new key-value pairs + if leaf_exists && !is_update { + // get the key-value pair for the key with the same prefix; since the key-value + // pair has already been inserted into the value store, we need to filter it out + // when looking for the other key-value pair + let (other_key, other_value) = self + .values + .get_first_filtered(index_to_prefix(&index), &key) + .expect("other key-value pair not found"); - // move the leaf to the new location; this requires first removing the existing - // index, re-computing node value, and inserting the node at a new location - let other_index = key_to_index(&other_key, depth); - let other_value = *self.values.get(&other_key).expect("no value for other key"); - self.upper_leaves.remove(&index).expect("other node key not in map"); - self.insert_leaf_node(other_index, other_key, other_value); + // determine how far down the tree should we move the leaves + let common_prefix_len = get_common_prefix_tier(&key, other_key); + let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); - // the new leaf also needs to move down to the same tier - index = key_to_index(&key, depth); - } + // compute node locations for new and existing key-value paris + let new_index = key_to_index(&key, depth); + let other_index = key_to_index(other_key, depth); + + // compute node values for the new and existing key-value pairs + let new_node = self.build_leaf_node(new_index, key, value); + let other_node = self.build_leaf_node(other_index, *other_key, *other_value); + + // replace the leaf located at index with a subtree containing nodes for new and + // existing key-value paris + self.root = self.nodes.replace_leaf_with_subtree( + index, + [(new_index, new_node), (other_index, other_node)], + ); + } else { + // if the returned index points to an empty subtree, or a leaf with the same key (i.e., + // we are performing an update), or a leaf is at the bottom tier, compute its node + // value and do a simple insert + let node = self.build_leaf_node(index, key, value); + 
self.root = self.nodes.insert_leaf_node(index, node); } - // insert the node and return the old value - self.insert_leaf_node(index, key, value); old_value } @@ -200,156 +219,114 @@ impl TieredSmt { /// /// The iterator order is unspecified. pub fn inner_nodes(&self) -> impl Iterator + '_ { - self.nodes.iter().filter_map(|(index, node)| { - if is_inner_node(index) { - Some(InnerNodeInfo { - value: *node, - left: self.get_node_unchecked(&index.left_child()), - right: self.get_node_unchecked(&index.right_child()), - }) - } else { - None - } - }) + self.nodes.inner_nodes() } - /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt]. - /// - /// Each yielded item is a (node, key, value) tuple where key is a full un-truncated key (i.e., - /// with key[3] element unmodified). + /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt] + /// where each yielded item is a (node, key, value) tuple. /// /// The iterator order is unspecified. pub fn upper_leaves(&self) -> impl Iterator + '_ { - self.upper_leaves.iter().map(|(index, key)| { - let node = self.get_node_unchecked(index); - let value = self.get_value(*key); - (node, *key, value) + self.nodes.upper_leaves().map(|(index, node)| { + let key_prefix = index_to_prefix(index); + let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found"); + (*node, *key, *value) }) } /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt]. /// /// Each yielded item consists of the hash of the leaf and its contents, where contents is - /// a vector containing key-value pairs of entries storied in this leaf. Note that keys are - /// un-truncated keys (i.e., with key[3] element unmodified). + /// a vector containing key-value pairs of entries storied in this leaf. /// /// The iterator order is unspecified. pub fn bottom_leaves(&self) -> impl Iterator)> + '_ { - self.bottom_leaves.values().map(|leaf| (leaf.hash(), leaf.contents())) + self.nodes.bottom_leaves().map(|(&prefix, node)| { + let values = self.values.get_all(prefix).expect("bottom leaf not found"); + (*node, values) + }) } // HELPER METHODS // -------------------------------------------------------------------------------------------- - /// Checks if the specified index is valid in the context of this Merkle tree. + /// Removes the node holding the key-value pair for the specified key from this tree, and + /// returns the value associated with the specified key. /// - /// # Errors - /// Returns an error if: - /// - The specified index depth is 0 or greater than 64. - /// - The node for the specified index does not exists in the Merkle tree. This is possible - /// when an ancestors of the specified index is a leaf node. - fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> { - if index.is_root() { - return Err(MerkleError::DepthTooSmall(index.depth())); - } else if index.depth() > Self::MAX_DEPTH { - return Err(MerkleError::DepthTooBig(index.depth() as u64)); + /// If no value was associated with the specified key, [ZERO; 4] is returned. + fn remove_leaf_node(&mut self, key: RpoDigest) -> Word { + // remove the key-value pair from the value store; if no value was associated with the + // specified key, return. 
+ let old_value = match self.values.remove(&key) { + Some(old_value) => old_value, + None => return Self::EMPTY_VALUE, + }; + + // determine the location of the leaf holding the key-value pair to be removed + let (index, leaf_exists) = self.nodes.get_insert_location(&key); + debug_assert!(index.depth() == Self::MAX_DEPTH || leaf_exists); + + // if the leaf is at the bottom tier and after removing the key-value pair from it, the + // leaf is still not empty, just recompute its hash and update the leaf node. + if index.depth() == Self::MAX_DEPTH { + if let Some(values) = self.values.get_all(index.value()) { + let node = hash_bottom_leaf(&values); + self.root = self.nodes.update_leaf_node(index, node); + return old_value; + }; + } + + // if the removed key-value pair has a lone sibling at the current tier with a root at + // higher tier, we need to move the sibling to a higher tier + if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) { + // determine the current index of the sibling node + let sib_index = key_to_index(sib_key, index.depth()); + debug_assert!(sib_index.depth() > new_sib_index.depth()); + + // compute node value for the new location of the sibling leaf and replace the subtree + // with this leaf node + let node = self.build_leaf_node(new_sib_index, *sib_key, *sib_val); + let new_sib_depth = new_sib_index.depth(); + self.root = self.nodes.replace_subtree_with_leaf(index, sib_index, new_sib_depth, node); } else { - // make sure that there are no leaf nodes in the ancestors of the index; since leaf - // nodes can live at specific depth, we just need to check these depths. - let tier = get_index_tier(&index); - let mut tier_index = index; - for &depth in Self::TIER_DEPTHS[..tier].iter().rev() { - tier_index.move_up_to(depth); - if self.upper_leaves.contains_key(&tier_index) { - return Err(MerkleError::NodeNotInSet(index)); - } - } + // if the removed key-value pair did not have a sibling at the current tier with a + // root at higher tiers, just clear the leaf node + self.root = self.nodes.clear_leaf_node(index); } - Ok(()) + old_value } - /// Returns a node at the specified index. If the node does not exist at this index, a root - /// for an empty subtree at the index's depth is returned. + /// Builds and returns a leaf node value for the node located as the specified index. /// - /// Unlike [TieredSmt::get_node()] this does not perform any checks to verify that the returned - /// node is valid in the context of this tree. - fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest { - match self.nodes.get(index) { - Some(node) => *node, - None => EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[index.depth() as usize], - } - } - - /// Returns an index at which a node for the specified key should be inserted. If a leaf node - /// already exists at that index, returns the key associated with that leaf node. - /// - /// In case the index falls into the bottom tier (depth = 64), leaf node key is not returned - /// as the bottom tier may contain multiple key-value pairs in the same leaf. - fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, Option) { - // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if - // a node at any of the tiers is either a leaf or a root of an empty subtree. 
- let mse = Word::from(key)[3].as_int(); - for depth in (Self::TIER_DEPTHS[0]..Self::MAX_DEPTH).step_by(Self::TIER_SIZE as usize) { - let index = NodeIndex::new_unchecked(depth, mse >> (Self::MAX_DEPTH - depth)); - if let Some(leaf_key) = self.upper_leaves.get(&index) { - return (index, Some(*leaf_key)); - } else if !self.nodes.contains_key(&index) { - return (index, None); - } - } - - // if we got here, that means all of the nodes checked so far are internal nodes, and - // the new node would need to be inserted in the bottom tier. - let index = NodeIndex::new_unchecked(Self::MAX_DEPTH, mse); - (index, None) - } - - /// Inserts the provided key-value pair at the specified index and updates the root of this - /// Merkle tree by recomputing the path to the root. - fn insert_leaf_node(&mut self, mut index: NodeIndex, key: RpoDigest, value: Word) { + /// This method assumes that the key-value pair for the node has already been inserted into + /// the value store, however, for depths 16, 32, and 48, the node is computed directly from + /// the passed-in values (for depth 64, the value store is queried to get all the key-value + /// pairs located at the specified index). + fn build_leaf_node(&self, index: NodeIndex, key: RpoDigest, value: Word) -> RpoDigest { let depth = index.depth(); debug_assert!(Self::TIER_DEPTHS.contains(&depth)); // insert the key into index-key map and compute the new value of the node - let mut node = if index.depth() == Self::MAX_DEPTH { + if index.depth() == Self::MAX_DEPTH { // for the bottom tier, we add the key-value pair to the existing leaf, or create a // new leaf with this key-value pair - self.bottom_leaves - .entry(index.value()) - .and_modify(|leaves| leaves.add_value(key, value)) - .or_insert(BottomLeaf::new(key, value)) - .hash() + let values = self.values.get_all(index.value()).unwrap(); + hash_bottom_leaf(&values) } else { - // for the upper tiers, we just update the index-key map and compute the value of the - // node - self.upper_leaves.insert(index, key); - // the node value is computed as: hash(key || value, domain = depth) - Rpo256::merge_in_domain(&[key, value.into()], depth.into()) - }; - - // insert the node and update the path from the node to the root - for _ in 0..index.depth() { - self.nodes.insert(index, node); - let sibling = self.get_node_unchecked(&index.sibling()); - node = Rpo256::merge(&index.build_node(node, sibling)); - index.move_up(); + debug_assert_eq!(self.values.get_first(index_to_prefix(&index)), Some(&(key, value))); + hash_upper_leaf(key, value, depth) } - - // update the root - self.nodes.insert(NodeIndex::root(), node); - self.root = node; } } impl Default for TieredSmt { fn default() -> Self { + let root = EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0]; Self { - root: EmptySubtreeRoots::empty_hashes(Self::MAX_DEPTH)[0], - nodes: BTreeMap::new(), - upper_leaves: BTreeMap::new(), - bottom_leaves: BTreeMap::new(), - values: BTreeMap::new(), + root, + nodes: NodeStore::new(root), + values: ValueStore::default(), } } } @@ -357,12 +334,23 @@ impl Default for TieredSmt { // HELPER FUNCTIONS // ================================================================================================ +/// Returns the value representing the 64 most significant bits of the specified key. +fn get_key_prefix(key: &RpoDigest) -> u64 { + Word::from(key)[3].as_int() +} + +/// Returns the index value shifted to be in the most significant bit positions of the returned +/// u64 value. 
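+///
+/// For example, an index at depth 16 with value 0xABCD maps to prefix 0xABCD_0000_0000_0000,
+/// i.e., the value is shifted left by 64 - 16 = 48 bits.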
+fn index_to_prefix(index: &NodeIndex) -> u64 { + index.value() << (TieredSmt::MAX_DEPTH - index.depth()) +} + /// Returns index for the specified key inserted at the specified depth. /// /// The value for the key is computed by taking n most significant bits from the most significant /// element of the key, where n is the specified depth. fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { - let mse = Word::from(key)[3].as_int(); + let mse = get_key_prefix(key); let value = match depth { 16 | 32 | 48 | 64 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32), _ => unreachable!("invalid depth: {depth}"), @@ -379,8 +367,8 @@ fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { /// - returns 16 if the common prefix is between 16 and 31 bits. /// - returns 0 if the common prefix is fewer than 16 bits. fn get_common_prefix_tier(key1: &RpoDigest, key2: &RpoDigest) -> u8 { - let e1 = Word::from(key1)[3].as_int(); - let e2 = Word::from(key2)[3].as_int(); + let e1 = get_key_prefix(key1); + let e2 = get_key_prefix(key2); let ex = (e1 ^ e2).leading_zeros() as u8; (ex / 16) * 16 } @@ -402,67 +390,30 @@ const fn get_index_tier(index: &NodeIndex) -> usize { } } -/// Returns true if the specified index is an index for an inner node (i.e., the depth is not 16, -/// 32, 48, or 64). -const fn is_inner_node(index: &NodeIndex) -> bool { - !matches!(index.depth(), 16 | 32 | 48 | 64) +/// Returns true if the specified index is an index for an leaf node (i.e., the depth is 16, 32, +/// 48, or 64). +const fn is_leaf_node(index: &NodeIndex) -> bool { + matches!(index.depth(), 16 | 32 | 48 | 64) } -// BOTTOM LEAF -// ================================================================================================ - -/// Stores contents of the bottom leaf (i.e., leaf at depth = 64) in a [TieredSmt]. +/// Computes node value for leaves at tiers 16, 32, or 48. /// -/// Bottom leaf can contain one or more key-value pairs all sharing the same 64-bit key prefix. -/// The values are sorted by key to make sure the structure of the leaf is independent of the -/// insertion order. This guarantees that a leaf with the same set of key-value pairs always has -/// the same hash value. -#[derive(Debug, Clone, PartialEq, Eq)] -struct BottomLeaf { - prefix: u64, - values: BTreeMap<[u64; 4], Word>, +/// Node value is computed as: hash(key || value, domain = depth). +pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest { + const NUM_UPPER_TIERS: usize = TieredSmt::TIER_DEPTHS.len() - 1; + debug_assert!(TieredSmt::TIER_DEPTHS[..NUM_UPPER_TIERS].contains(&depth)); + Rpo256::merge_in_domain(&[key, value.into()], depth.into()) } -impl BottomLeaf { - /// Returns a new [BottomLeaf] with a single key-value pair added. - pub fn new(key: RpoDigest, value: Word) -> Self { - let prefix = Word::from(key)[3].as_int(); - let mut values = BTreeMap::new(); - values.insert(key.into(), value); - Self { prefix, values } - } - - /// Adds a new key-value pair to this leaf. - pub fn add_value(&mut self, key: RpoDigest, value: Word) { - self.values.insert(key.into(), value); - } - - /// Computes a hash of this leaf. - pub fn hash(&self) -> RpoDigest { - let mut elements = Vec::with_capacity(self.values.len() * 2); - for (key, val) in self.values.iter() { - key.iter().for_each(|&v| elements.push(Felt::new(v))); - elements.extend_from_slice(val.as_slice()); - } - // TODO: hash in domain - Rpo256::hash_elements(&elements) - } - - /// Returns contents of this leaf as a vector of (key, value) pairs. 
- /// - /// The keys are returned in their un-truncated form. - pub fn contents(&self) -> Vec<(RpoDigest, Word)> { - self.values - .iter() - .map(|(key, val)| { - let key = RpoDigest::from([ - Felt::new(key[0]), - Felt::new(key[1]), - Felt::new(key[2]), - Felt::new(key[3]), - ]); - (key, *val) - }) - .collect() +/// Computes node value for leaves at the bottom tier (depth 64). +/// +/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n, domain=64]). +pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest { + let mut elements = Vec::with_capacity(values.len() * 8); + for (key, val) in values.iter() { + elements.extend_from_slice(key.as_elements()); + elements.extend_from_slice(val.as_slice()); } + // TODO: hash in domain + Rpo256::hash_elements(&elements) } diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs new file mode 100644 index 0000000..42bad5e --- /dev/null +++ b/src/merkle/tiered_smt/nodes.rs @@ -0,0 +1,356 @@ +use super::{ + get_index_tier, get_key_prefix, is_leaf_node, BTreeMap, BTreeSet, EmptySubtreeRoots, + InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, +}; + +// CONSTANTS +// ================================================================================================ + +/// The number of levels between tiers. +const TIER_SIZE: u8 = super::TieredSmt::TIER_SIZE; + +/// Depths at which leaves can exist in a tiered SMT. +const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS; + +/// Maximum node depth. This is also the bottom tier of the tree. +const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; + +// NODE STORE +// ================================================================================================ + +/// A store of nodes for a Tiered Sparse Merkle tree. +/// +/// The store contains information about all nodes as well as information about which of the nodes +/// represent leaf nodes in a Tiered Sparse Merkle tree. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NodeStore { + nodes: BTreeMap, + upper_leaves: BTreeSet, + bottom_leaves: BTreeSet, +} + +impl NodeStore { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Returns a new instance of [NodeStore] instantiated with the specified root node. + /// + /// Root node is assumed to be a root of an empty sparse Merkle tree. + pub fn new(root_node: RpoDigest) -> Self { + let mut nodes = BTreeMap::default(); + nodes.insert(NodeIndex::root(), root_node); + + Self { + nodes, + upper_leaves: BTreeSet::default(), + bottom_leaves: BTreeSet::default(), + } + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns a node at the specified index. + /// + /// # Errors + /// Returns an error if: + /// - The specified index depth is 0 or greater than 64. + /// - The node with the specified index does not exists in the Merkle tree. This is possible + /// when a leaf node with the same index prefix exists at a tier higher than the requested + /// node. + pub fn get_node(&self, index: NodeIndex) -> Result { + self.validate_node_access(index)?; + Ok(self.get_node_unchecked(&index)) + } + + /// Returns a Merkle path from the node at the specified index to the root. + /// + /// The node itself is not included in the path. + /// + /// # Errors + /// Returns an error if: + /// - The specified index depth is 0 or greater than 64. 
+    /// - The node with the specified index does not exist in the Merkle tree. This is possible
+    ///   when a leaf node with the same index prefix exists at a tier higher than the node to
+    ///   which the path is requested.
+    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
+        self.validate_node_access(index)?;
+
+        let mut path = Vec::with_capacity(index.depth() as usize);
+        for _ in 0..index.depth() {
+            let node = self.get_node_unchecked(&index.sibling());
+            path.push(node);
+            index.move_up();
+        }
+
+        Ok(path.into())
+    }
+
+    /// Returns an index at which a leaf node for the specified key should be inserted.
+    ///
+    /// The second value in the returned tuple is set to true if the node at the returned index
+    /// is already a leaf node, excluding leaves at the bottom tier (i.e., if the leaf is at the
+    /// bottom tier, false is returned).
+    pub fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, bool) {
+        // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if
+        // a node at any of the tiers is either a leaf or a root of an empty subtree.
+        let mse = get_key_prefix(key);
+        for depth in (TIER_DEPTHS[0]..MAX_DEPTH).step_by(TIER_SIZE as usize) {
+            let index = NodeIndex::new_unchecked(depth, mse >> (MAX_DEPTH - depth));
+            if self.upper_leaves.contains(&index) {
+                return (index, true);
+            } else if !self.nodes.contains_key(&index) {
+                return (index, false);
+            }
+        }
+
+        // if we got here, that means all of the nodes checked so far are internal nodes, and
+        // the new node would need to be inserted in the bottom tier.
+        let index = NodeIndex::new_unchecked(MAX_DEPTH, mse);
+        (index, false)
+    }
+
+    // ITERATORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns an iterator over all inner nodes of the Tiered Sparse Merkle tree (i.e., nodes not
+    /// at depths 16, 32, 48, or 64).
+    ///
+    /// The iterator order is unspecified.
+    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
+        self.nodes.iter().filter_map(|(index, node)| {
+            if !is_leaf_node(index) {
+                Some(InnerNodeInfo {
+                    value: *node,
+                    left: self.get_node_unchecked(&index.left_child()),
+                    right: self.get_node_unchecked(&index.right_child()),
+                })
+            } else {
+                None
+            }
+        })
+    }
+
+    /// Returns an iterator over the upper leaves (i.e., leaves with depths 16, 32, or 48) of the
+    /// Tiered Sparse Merkle tree.
+    pub fn upper_leaves(&self) -> impl Iterator<Item = (&NodeIndex, &RpoDigest)> {
+        self.upper_leaves.iter().map(|index| (index, &self.nodes[index]))
+    }
+
+    /// Returns an iterator over the bottom leaves (i.e., leaves with depth 64) of the Tiered
+    /// Sparse Merkle tree.
+    pub fn bottom_leaves(&self) -> impl Iterator<Item = (&u64, &RpoDigest)> {
+        self.bottom_leaves.iter().map(|value| {
+            let index = NodeIndex::new_unchecked(MAX_DEPTH, *value);
+            (value, &self.nodes[&index])
+        })
+    }
+
+    // STATE MUTATORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Replaces the leaf node at the specified index with a tree consisting of two leaves located
+    /// at the specified indexes. Recomputes and returns the new root.
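+    ///
+    /// This is used by the insertion logic in the parent module when a new key collides with an
+    /// existing upper-tier leaf: both leaves are pushed down to the first tier at which their key
+    /// prefixes diverge (or to the bottom tier).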
+ pub fn replace_leaf_with_subtree( + &mut self, + leaf_index: NodeIndex, + subtree_leaves: [(NodeIndex, RpoDigest); 2], + ) -> RpoDigest { + debug_assert!(is_leaf_node(&leaf_index)); + debug_assert!(is_leaf_node(&subtree_leaves[0].0)); + debug_assert!(is_leaf_node(&subtree_leaves[1].0)); + debug_assert!(!is_empty_root(&subtree_leaves[0].1)); + debug_assert!(!is_empty_root(&subtree_leaves[1].1)); + debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth()); + debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth()); + + self.upper_leaves.remove(&leaf_index); + self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1); + self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1) + } + + /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node + /// containing the retained leaf. + /// + /// This has the effect of deleting the the node at the `removed_leaf` index from the tree, + /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`. + pub fn replace_subtree_with_leaf( + &mut self, + removed_leaf: NodeIndex, + retained_leaf: NodeIndex, + new_depth: u8, + node: RpoDigest, + ) -> RpoDigest { + debug_assert!(!is_empty_root(&node)); + debug_assert!(self.is_leaf(&removed_leaf)); + debug_assert!(self.is_leaf(&retained_leaf)); + debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth()); + debug_assert!(removed_leaf.depth() > new_depth); + + // clear leaf flags + if removed_leaf.depth() == MAX_DEPTH { + self.bottom_leaves.remove(&removed_leaf.value()); + self.bottom_leaves.remove(&retained_leaf.value()); + } else { + self.upper_leaves.remove(&removed_leaf); + self.upper_leaves.remove(&retained_leaf); + } + + // remove the branches leading up to the tier to which the retained leaf is to be moved + self.remove_branch(removed_leaf, new_depth); + self.remove_branch(retained_leaf, new_depth); + + // compute the index of the common root for retained and removed leaves + let mut new_index = retained_leaf; + new_index.move_up_to(new_depth); + debug_assert!(is_leaf_node(&new_index)); + + // insert the node at the root index + self.insert_leaf_node(new_index, node) + } + + /// Inserts the specified node at the specified index; recomputes and returns the new root + /// of the Tiered Sparse Merkle tree. + /// + /// This method assumes that node is a non-empty value. + pub fn insert_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest { + debug_assert!(is_leaf_node(&index)); + debug_assert!(!is_empty_root(&node)); + + // mark the node as the leaf + if index.depth() == MAX_DEPTH { + self.bottom_leaves.insert(index.value()); + } else { + self.upper_leaves.insert(index); + }; + + // insert the node and update the path from the node to the root + for _ in 0..index.depth() { + self.nodes.insert(index, node); + let sibling = self.get_node_unchecked(&index.sibling()); + node = Rpo256::merge(&index.build_node(node, sibling)); + index.move_up(); + } + + // update the root + self.nodes.insert(NodeIndex::root(), node); + node + } + + /// Updates the node at the specified index with the specified node value; recomputes and + /// returns the new root of the Tiered Sparse Merkle tree. + /// + /// This method can accept `node` as either an empty or a non-empty value. 
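+    ///
+    /// Passing the root of an empty subtree for `node` effectively clears the leaf: the leaf flag
+    /// is removed and nodes along the path to the root are pruned wherever they become roots of
+    /// empty subtrees.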
+    pub fn update_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest {
+        debug_assert!(self.is_leaf(&index));
+
+        // if the value we are updating the node to is a root of an empty tree, clear the leaf
+        // flag for this node
+        if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
+            if index.depth() == MAX_DEPTH {
+                self.bottom_leaves.remove(&index.value());
+            } else {
+                self.upper_leaves.remove(&index);
+            }
+        } else {
+            debug_assert!(!is_empty_root(&node));
+        }
+
+        // update the path from the node to the root
+        for _ in 0..index.depth() {
+            if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] {
+                self.nodes.remove(&index);
+            } else {
+                self.nodes.insert(index, node);
+            }
+
+            let sibling = self.get_node_unchecked(&index.sibling());
+            node = Rpo256::merge(&index.build_node(node, sibling));
+            index.move_up();
+        }
+
+        // update the root
+        self.nodes.insert(NodeIndex::root(), node);
+        node
+    }
+
+    /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes
+    /// and returns the new root of the Tiered Sparse Merkle tree.
+    pub fn clear_leaf_node(&mut self, index: NodeIndex) -> RpoDigest {
+        debug_assert!(self.is_leaf(&index));
+        let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize];
+        self.update_leaf_node(index, node)
+    }
+
+    // HELPER METHODS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns true if the node at the specified index is a leaf node.
+    fn is_leaf(&self, index: &NodeIndex) -> bool {
+        debug_assert!(is_leaf_node(index));
+        if index.depth() == MAX_DEPTH {
+            self.bottom_leaves.contains(&index.value())
+        } else {
+            self.upper_leaves.contains(index)
+        }
+    }
+
+    /// Checks if the specified index is valid in the context of this Merkle tree.
+    ///
+    /// # Errors
+    /// Returns an error if:
+    /// - The specified index depth is 0 or greater than 64.
+    /// - The node for the specified index does not exist in the Merkle tree. This is possible
+    ///   when an ancestor of the specified index is a leaf node.
+    fn validate_node_access(&self, index: NodeIndex) -> Result<(), MerkleError> {
+        if index.is_root() {
+            return Err(MerkleError::DepthTooSmall(index.depth()));
+        } else if index.depth() > MAX_DEPTH {
+            return Err(MerkleError::DepthTooBig(index.depth() as u64));
+        } else {
+            // make sure that there are no leaf nodes in the ancestors of the index; since leaf
+            // nodes can live only at specific depths, we just need to check these depths.
+            let tier = get_index_tier(&index);
+            let mut tier_index = index;
+            for &depth in TIER_DEPTHS[..tier].iter().rev() {
+                tier_index.move_up_to(depth);
+                if self.upper_leaves.contains(&tier_index) {
+                    return Err(MerkleError::NodeNotInSet(index));
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Returns a node at the specified index. If the node does not exist at this index, a root
+    /// for an empty subtree at the index's depth is returned.
+    ///
+    /// Unlike [NodeStore::get_node()] this does not perform any checks to verify that the
+    /// returned node is valid in the context of this tree.
+    fn get_node_unchecked(&self, index: &NodeIndex) -> RpoDigest {
+        match self.nodes.get(index) {
+            Some(node) => *node,
+            None => EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize],
+        }
+    }
+
+    /// Removes a sequence of nodes starting at the specified index and traversing the
+    /// tree up to the specified depth.
+ /// + /// This method does not update any other nodes and does not recompute the tree root. + fn remove_branch(&mut self, mut index: NodeIndex, end_depth: u8) { + assert!(index.depth() > end_depth); + for _ in 0..(index.depth() - end_depth) { + self.nodes.remove(&index); + index.move_up() + } + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Returns true if the specified node is a root of an empty tree or an empty value ([ZERO; 4]). +fn is_empty_root(node: &RpoDigest) -> bool { + EmptySubtreeRoots::empty_hashes(MAX_DEPTH).contains(node) +} diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index d8e6723..845e76e 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -3,6 +3,9 @@ use super::{ EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word, }; +// INSERTION TESTS +// ================================================================================================ + #[test] fn tsmt_insert_one() { let mut smt = TieredSmt::default(); @@ -216,6 +219,9 @@ fn tsmt_insert_three() { actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); } +// UPDATE TESTS +// ================================================================================================ + #[test] fn tsmt_update() { let mut smt = TieredSmt::default(); @@ -251,6 +257,209 @@ fn tsmt_update() { actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); } +// DELETION TESTS +// ================================================================================================ + +#[test] +fn tsmt_delete_16() { + let mut smt = TieredSmt::default(); + + // --- insert a value into the tree --------------------------------------- + let smt0 = smt.clone(); + let raw_a = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert another value into the tree --------------------------------- + let smt1 = smt.clone(); + let raw_b = 0b_01011111_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt, smt1); + + // --- delete the first inserted value ------------------------------------ + assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt, smt0); +} + +#[test] +fn tsmt_delete_32() { + let mut smt = TieredSmt::default(); + + // --- insert a value into the tree --------------------------------------- + let smt0 = smt.clone(); + let raw_a = 0b_01010101_01101100_01111111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert another with the same 16-bit prefix into the tree ----------- + let smt1 = smt.clone(); + let raw_b = 0b_01010101_01101100_00111111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // --- insert the 3rd value with the same 16-bit prefix into the tree ----- + 
let smt2 = smt.clone(); + let raw_c = 0b_01010101_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt, smt2); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt, smt1); + + // --- delete the first inserted value ------------------------------------ + assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt, smt0); +} + +#[test] +fn tsmt_delete_48_same_32_bit_prefix() { + let mut smt = TieredSmt::default(); + + // test the case when all values share the same 32-bit prefix + + // --- insert a value into the tree --------------------------------------- + let smt0 = smt.clone(); + let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert another with the same 32-bit prefix into the tree ----------- + let smt1 = smt.clone(); + let raw_b = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // --- insert the 3rd value with the same 32-bit prefix into the tree ----- + let smt2 = smt.clone(); + let raw_c = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt, smt2); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt, smt1); + + // --- delete the first inserted value ------------------------------------ + assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt, smt0); +} + +#[test] +fn tsmt_delete_48_mixed_prefix() { + let mut smt = TieredSmt::default(); + + // test the case when some values share a 32-bit prefix and others share a 16-bit prefix + + // --- insert a value into the tree --------------------------------------- + let smt0 = smt.clone(); + let raw_a = 0b_01010101_01010101_11111111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert another with the same 16-bit prefix into the tree ----------- + let smt1 = smt.clone(); + let raw_b = 0b_01010101_01010101_01111111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // --- insert a value with the same 32-bit prefix as the first value ----- + let smt2 = smt.clone(); + let raw_c = 0b_01010101_01010101_11111111_11111111_11010110_10010011_11100000_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + 
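+    // at this point key_a and key_c share a 32-bit prefix and should therefore sit on the
+    // depth-48 tier, while key_b, which shares only a 16-bit prefix with them, stays at depth 32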
+ // --- insert another value with the same 32-bit prefix as the first value + let smt3 = smt.clone(); + let raw_d = 0b_01010101_01010101_11111111_11111111_11110110_10010011_11100000_00000000_u64; + let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_d)]); + let value_d = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_d, value_d); + + // --- delete the inserted values one-by-one ------------------------------ + assert_eq!(smt.insert(key_d, [ZERO; 4]), value_d); + assert_eq!(smt, smt3); + + assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt, smt2); + + assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt, smt1); + + assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt, smt0); +} + +#[test] +fn tsmt_delete_64() { + let mut smt = TieredSmt::default(); + + // test the case when all values share the same 48-bit prefix + + // --- insert a value into the tree --------------------------------------- + let smt0 = smt.clone(); + let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert a value with the same 48-bit prefix into the tree ----------- + let smt1 = smt.clone(); + let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // --- insert a value with the same 32-bit prefix into the tree ----------- + let smt2 = smt.clone(); + let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + + let smt3 = smt.clone(); + let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64; + let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]); + let value_d = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_d, value_d); + + // --- delete the last inserted value ------------------------------------- + assert_eq!(smt.insert(key_d, [ZERO; 4]), value_d); + assert_eq!(smt, smt3); + + assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt, smt2); + + assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt, smt1); + + assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt, smt0); +} + // BOTTOM TIER TESTS // ================================================================================================ diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs new file mode 100644 index 0000000..b80aebd --- /dev/null +++ b/src/merkle/tiered_smt/values.rs @@ -0,0 +1,580 @@ +use super::{get_key_prefix, is_leaf_node, BTreeMap, NodeIndex, RpoDigest, StarkField, Vec, Word}; +use crate::utils::vec; +use core::{ + cmp::{Ord, Ordering}, + ops::RangeBounds, +}; +use winter_utils::collections::btree_map::Entry; + +// CONSTANTS +// ================================================================================================ + +/// Depths at which leaves can exist in a tiered SMT. +const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS; + +/// Maximum node depth. This is also the bottom tier of the tree. 
+const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; + +// VALUE STORE +// ================================================================================================ +/// A store for key-value pairs for a Tiered Sparse Merkle tree. +/// +/// The store is organized in a [BTreeMap] where keys are 64 most significant bits of a key, and +/// the values are the corresponding key-value pairs (or a list of key-value pairs if more that +/// a single key-value pair shares the same 64-bit prefix). +/// +/// The store supports lookup by the full key as well as by the 64-bit key prefix. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct ValueStore { + values: BTreeMap, +} + +impl ValueStore { + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns a reference to the value stored under the specified key, or None if there is no + /// value associated with the specified key. + pub fn get(&self, key: &RpoDigest) -> Option<&Word> { + let prefix = get_key_prefix(key); + self.values.get(&prefix).and_then(|entry| entry.get(key)) + } + + /// Returns the first key-value pair such that the key prefix is greater than or equal to the + /// specified prefix. + pub fn get_first(&self, prefix: u64) -> Option<&(RpoDigest, Word)> { + self.range(prefix..).next() + } + + /// Returns the first key-value pair such that the key prefix is greater than or equal to the + /// specified prefix and the key value is not equal to the exclude_key value. + pub fn get_first_filtered( + &self, + prefix: u64, + exclude_key: &RpoDigest, + ) -> Option<&(RpoDigest, Word)> { + self.range(prefix..).find(|(key, _)| key != exclude_key) + } + + /// Returns a vector with key-value pairs for all keys with the specified 64-bit prefix, or + /// None if no keys with the specified prefix are present in this store. + pub fn get_all(&self, prefix: u64) -> Option> { + self.values.get(&prefix).map(|entry| match entry { + StoreEntry::Single(kv_pair) => vec![*kv_pair], + StoreEntry::List(kv_pairs) => kv_pairs.clone(), + }) + } + + /// Returns information about a sibling of a leaf node with the specified index, but only if + /// this is the only sibling the leaf has in some subtree starting at the first tier. + /// + /// For example, if `index` is an index at depth 32, and there is a leaf node at depth 32 with + /// the same root at depth 16 as `index`, we say that this leaf is a lone sibling. + /// + /// The returned tuple contains: they key-value pair of the sibling as well as the index of + /// the node for the root of the common subtree in which both nodes are leaves. + /// + /// This method assumes that the key-value pair for the specified index has already been + /// removed from the store. + pub fn get_lone_sibling(&self, index: NodeIndex) -> Option<(&RpoDigest, &Word, NodeIndex)> { + debug_assert!(is_leaf_node(&index)); + + // iterate over tiers from top to bottom, looking at the tiers which are strictly above + // the depth of the index. This implies that only tiers at depth 32 and 48 will be + // considered. For each tier, check if the parent of the index at the higher tier + // contains a single node. + for &tier in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) { + // compute the index of the root at a higher tier + let mut parent_index = index; + parent_index.move_up_to(tier); + + // find the lone sibling, if any; we need to handle the "last node" at a given tier + // separately specify the bounds for the search correctly. 
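+            // (if the parent's prefix consists of all ones, adding one to compute an exclusive
+            // end bound would wrap the shifted value to zero and produce an empty range, so an
+            // open-ended range is used in that case)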
+ let start_prefix = parent_index.value() << (MAX_DEPTH - tier); + let sibling = if start_prefix.leading_ones() as u8 == tier { + let mut iter = self.range(start_prefix..); + iter.next().filter(|_| iter.next().is_none()) + } else { + let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier); + let mut iter = self.range(start_prefix..end_prefix); + iter.next().filter(|_| iter.next().is_none()) + }; + + if let Some((key, value)) = sibling { + return Some((key, value, parent_index)); + } + } + + None + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Inserts the specified key-value pair into this store and returns the value previously + /// associated with the specified key. + /// + /// If no value was previously associated with the specified key, None is returned. + pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option { + let prefix = get_key_prefix(&key); + match self.values.entry(prefix) { + Entry::Occupied(mut entry) => entry.get_mut().insert(key, value), + Entry::Vacant(entry) => { + entry.insert(StoreEntry::new(key, value)); + None + } + } + } + + /// Removes the key-value pair for the specified key from this store and returns the value + /// associated with this key. + /// + /// If no value was associated with the specified key, None is returned. + pub fn remove(&mut self, key: &RpoDigest) -> Option { + let prefix = get_key_prefix(key); + match self.values.entry(prefix) { + Entry::Occupied(mut entry) => { + let (value, remove_entry) = entry.get_mut().remove(key); + if remove_entry { + entry.remove_entry(); + } + value + } + Entry::Vacant(_) => None, + } + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Returns an iterator over all key-value pairs contained in this store such that the most + /// significant 64 bits of the key lay within the specified bounds. + /// + /// The order of iteration is from the smallest to the largest key. + fn range>(&self, bounds: R) -> impl Iterator { + self.values.range(bounds).flat_map(|(_, entry)| entry.iter()) + } +} + +// VALUE NODE +// ================================================================================================ + +/// An entry in the [ValueStore]. +/// +/// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by +/// key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StoreEntry { + Single((RpoDigest, Word)), + List(Vec<(RpoDigest, Word)>), +} + +impl StoreEntry { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Returns a new [StoreEntry] instantiated with a single key-value pair. + pub fn new(key: RpoDigest, value: Word) -> Self { + Self::Single((key, value)) + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the value associated with the specified key, or None if this entry does not contain + /// a value associated with the specified key. 
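+    ///
+    /// For a `List` entry, the lookup is a binary search over the pairs, which are kept sorted
+    /// by key.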
+ pub fn get(&self, key: &RpoDigest) -> Option<&Word> { + match self { + StoreEntry::Single(kv_pair) => { + if kv_pair.0 == *key { + Some(&kv_pair.1) + } else { + None + } + } + StoreEntry::List(kv_pairs) => { + match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) { + Ok(pos) => Some(&kv_pairs[pos].1), + Err(_) => None, + } + } + } + } + + /// Returns an iterator over all key-value pairs in this entry. + pub fn iter(&self) -> impl Iterator { + EntryIterator { + entry: self, + pos: 0, + } + } + + // STATE MUTATORS + // -------------------------------------------------------------------------------------------- + + /// Inserts the specified key-value pair into this entry and returns the value previously + /// associated with the specified key, or None if no value was associated with the specified + /// key. + /// + /// If a new key is inserted, this will also transform a `SingleEntry` into a `ListEntry`. + pub fn insert(&mut self, key: RpoDigest, value: Word) -> Option { + match self { + StoreEntry::Single(kv_pair) => { + // if the key is already in this entry, update the value and return + if kv_pair.0 == key { + let old_value = kv_pair.1; + kv_pair.1 = value; + return Some(old_value); + } + + // transform the entry into a list entry, and make sure the key-value pairs + // are sorted by key + let mut pairs = vec![*kv_pair, (key, value)]; + pairs.sort_by(|a, b| cmp_digests(&a.0, &b.0)); + + *self = StoreEntry::List(pairs); + None + } + StoreEntry::List(pairs) => { + match pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, &key)) { + Ok(pos) => { + let old_value = pairs[pos].1; + pairs[pos].1 = value; + Some(old_value) + } + Err(pos) => { + pairs.insert(pos, (key, value)); + None + } + } + } + } + } + + /// Removes the key-value pair with the specified key from this entry, and returns the value + /// of the removed pair. If the entry did not contain a key-value pair for the specified key, + /// None is returned. + /// + /// If the last last key-value pair was removed from the entry, the second tuple value will + /// be set to true. + pub fn remove(&mut self, key: &RpoDigest) -> (Option, bool) { + match self { + StoreEntry::Single(kv_pair) => { + if kv_pair.0 == *key { + (Some(kv_pair.1), true) + } else { + (None, false) + } + } + StoreEntry::List(kv_pairs) => { + match kv_pairs.binary_search_by(|kv_pair| cmp_digests(&kv_pair.0, key)) { + Ok(pos) => { + let kv_pair = kv_pairs.remove(pos); + if kv_pairs.len() == 1 { + *self = StoreEntry::Single(kv_pairs[0]); + } + (Some(kv_pair.1), false) + } + Err(_) => (None, false), + } + } + } + } +} + +/// A custom iterator over key-value pairs of a [StoreEntry]. +/// +/// For a `SingleEntry` this returns only one value, but for `ListEntry`, this iterates over the +/// entire list of key-value pairs. 
+pub struct EntryIterator<'a> { + entry: &'a StoreEntry, + pos: usize, +} + +impl<'a> Iterator for EntryIterator<'a> { + type Item = &'a (RpoDigest, Word); + + fn next(&mut self) -> Option { + match self.entry { + StoreEntry::Single(kv_pair) => { + if self.pos == 0 { + self.pos = 1; + Some(kv_pair) + } else { + None + } + } + StoreEntry::List(kv_pairs) => { + if self.pos >= kv_pairs.len() { + None + } else { + let kv_pair = &kv_pairs[self.pos]; + self.pos += 1; + Some(kv_pair) + } + } + } + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Compares two digests element-by-element using their integer representations starting with the +/// most significant element. +fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering { + let d1 = Word::from(d1); + let d2 = Word::from(d2); + + for (v1, v2) in d1.iter().zip(d2.iter()).rev() { + let v1 = v1.as_int(); + let v2 = v2.as_int(); + if v1 != v2 { + return v1.cmp(&v2); + } + } + + Ordering::Equal +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + + use super::{RpoDigest, ValueStore}; + use crate::{ + merkle::{tiered_smt::values::StoreEntry, NodeIndex}, + Felt, ONE, WORD_SIZE, ZERO, + }; + + #[test] + fn test_insert() { + let mut store = ValueStore::default(); + + // insert the first key-value pair into the store + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE; WORD_SIZE]; + + assert!(store.insert(key_a, value_a).is_none()); + assert_eq!(store.values.len(), 1); + + let entry = store.values.get(&raw_a).unwrap(); + let expected_entry = StoreEntry::Single((key_a, value_a)); + assert_eq!(entry, &expected_entry); + + // insert a key-value pair with a different key into the store; since the keys are + // different, another entry is added to the values map + let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ZERO, ONE, ZERO]; + + assert!(store.insert(key_b, value_b).is_none()); + assert_eq!(store.values.len(), 2); + + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = StoreEntry::Single((key_a, value_a)); + assert_eq!(entry1, &expected_entry1); + + let entry2 = store.values.get(&raw_b).unwrap(); + let expected_entry2 = StoreEntry::Single((key_b, value_b)); + assert_eq!(entry2, &expected_entry2); + + // insert a key-value pair with the same 64-bit key prefix as the first key; this should + // transform the first entry into a List entry + let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + + assert!(store.insert(key_c, value_c).is_none()); + assert_eq!(store.values.len(), 2); + + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a)]); + assert_eq!(entry1, &expected_entry1); + + let entry2 = store.values.get(&raw_b).unwrap(); + let expected_entry2 = StoreEntry::Single((key_b, value_b)); + assert_eq!(entry2, &expected_entry2); + + // replace values for keys a and b + let value_a2 = [ONE, ONE, ONE, ZERO]; + let value_b2 = [ZERO, ZERO, ZERO, ONE]; + + assert_eq!(store.insert(key_a, value_a2), Some(value_a)); + assert_eq!(store.values.len(), 2); + + 
assert_eq!(store.insert(key_b, value_b2), Some(value_b)); + assert_eq!(store.values.len(), 2); + + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2)]); + assert_eq!(entry1, &expected_entry1); + + let entry2 = store.values.get(&raw_b).unwrap(); + let expected_entry2 = StoreEntry::Single((key_b, value_b2)); + assert_eq!(entry2, &expected_entry2); + + // insert one more key-value pair with the same 64-bit key-prefix as the first key + let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_d = [ZERO, ONE, ZERO, ZERO]; + + assert!(store.insert(key_d, value_d).is_none()); + assert_eq!(store.values.len(), 2); + + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = + StoreEntry::List(vec![(key_c, value_c), (key_a, value_a2), (key_d, value_d)]); + assert_eq!(entry1, &expected_entry1); + + let entry2 = store.values.get(&raw_b).unwrap(); + let expected_entry2 = StoreEntry::Single((key_b, value_b2)); + assert_eq!(entry2, &expected_entry2); + } + + #[test] + fn test_remove() { + // populate the value store + let mut store = ValueStore::default(); + + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE; WORD_SIZE]; + store.insert(key_a, value_a); + + let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ZERO, ONE, ZERO]; + store.insert(key_b, value_b); + + let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + store.insert(key_c, value_c); + + let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_d = [ZERO, ONE, ZERO, ZERO]; + store.insert(key_d, value_d); + + assert_eq!(store.values.len(), 2); + + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = + StoreEntry::List(vec![(key_c, value_c), (key_a, value_a), (key_d, value_d)]); + assert_eq!(entry1, &expected_entry1); + + let entry2 = store.values.get(&raw_b).unwrap(); + let expected_entry2 = StoreEntry::Single((key_b, value_b)); + assert_eq!(entry2, &expected_entry2); + + // remove non-existent keys + let key_e = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_a)]); + assert!(store.remove(&key_e).is_none()); + + let raw_f = 0b_11111110_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_f = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_f)]); + assert!(store.remove(&key_f).is_none()); + + // remove keys from the list entry + assert_eq!(store.remove(&key_c).unwrap(), value_c); + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = StoreEntry::List(vec![(key_a, value_a), (key_d, value_d)]); + assert_eq!(entry1, &expected_entry1); + + assert_eq!(store.remove(&key_a).unwrap(), value_a); + let entry1 = store.values.get(&raw_a).unwrap(); + let expected_entry1 = StoreEntry::Single((key_d, value_d)); + assert_eq!(entry1, &expected_entry1); + + assert_eq!(store.remove(&key_d).unwrap(), value_d); + assert!(store.values.get(&raw_a).is_none()); + assert_eq!(store.values.len(), 1); + + // remove a key from a single entry + assert_eq!(store.remove(&key_b).unwrap(), value_b); + assert!(store.values.get(&raw_b).is_none()); + assert_eq!(store.values.len(), 0); + } + + #[test] + fn test_range() { + // populate the value store + let mut store = ValueStore::default(); 
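+        // keys a, c, and d below share the same 64-bit prefix (raw_a), while b and e have their
+        // own prefixes; iteration is ordered by prefix first and then by the full key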
+ + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE; WORD_SIZE]; + store.insert(key_a, value_a); + + let raw_b = 0b_11111110_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ZERO, ONE, ZERO]; + store.insert(key_b, value_b); + + let key_c = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + store.insert(key_c, value_c); + + let key_d = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_d = [ZERO, ONE, ZERO, ZERO]; + store.insert(key_d, value_d); + + let raw_e = 0b_10101000_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_e = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_e)]); + let value_e = [ZERO, ZERO, ZERO, ONE]; + store.insert(key_e, value_e); + + // check the entire range + let mut iter = store.range(..u64::MAX); + assert_eq!(iter.next(), Some(&(key_e, value_e))); + assert_eq!(iter.next(), Some(&(key_c, value_c))); + assert_eq!(iter.next(), Some(&(key_a, value_a))); + assert_eq!(iter.next(), Some(&(key_d, value_d))); + assert_eq!(iter.next(), Some(&(key_b, value_b))); + assert_eq!(iter.next(), None); + + // check all but e + let mut iter = store.range(raw_a..u64::MAX); + assert_eq!(iter.next(), Some(&(key_c, value_c))); + assert_eq!(iter.next(), Some(&(key_a, value_a))); + assert_eq!(iter.next(), Some(&(key_d, value_d))); + assert_eq!(iter.next(), Some(&(key_b, value_b))); + assert_eq!(iter.next(), None); + + // check all but e and b + let mut iter = store.range(raw_a..raw_b); + assert_eq!(iter.next(), Some(&(key_c, value_c))); + assert_eq!(iter.next(), Some(&(key_a, value_a))); + assert_eq!(iter.next(), Some(&(key_d, value_d))); + assert_eq!(iter.next(), None); + } + + #[test] + fn test_get_lone_sibling() { + // populate the value store + let mut store = ValueStore::default(); + + let raw_a = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_a = RpoDigest::from([ZERO, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE; WORD_SIZE]; + store.insert(key_a, value_a); + + let raw_b = 0b_11111111_11111111_00011111_11111111_10010110_10010011_11100000_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ZERO, ONE, ZERO]; + store.insert(key_b, value_b); + + // check sibling node for `a` + let index = NodeIndex::make(32, 0b_10101010_10101010_00011111_11111110); + let parent_index = NodeIndex::make(16, 0b_10101010_10101010); + assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index))); + + // check sibling node for `b` + let index = NodeIndex::make(32, 0b_11111111_11111111_00011111_11111111); + let parent_index = NodeIndex::make(16, 0b_11111111_11111111); + assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index))); + + // check some other sibling for some other index + let index = NodeIndex::make(32, 0b_11101010_10101010); + assert_eq!(store.get_lone_sibling(index), None); + } +} From a03f2b5d5e21df2ea0efdbd91be20aa6a32a060d Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Tue, 1 Aug 2023 11:02:29 -0700 Subject: [PATCH 09/32] feat: implement iterator over key-value pairs for TSMT --- src/merkle/tiered_smt/mod.rs | 6 ++++++ src/merkle/tiered_smt/values.rs | 5 +++++ 2 files changed, 11 insertions(+) diff --git 
a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 52a3f89..b8dd52f 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -214,6 +214,11 @@ impl TieredSmt { // ITERATORS // -------------------------------------------------------------------------------------------- + /// Returns an iterator over all key-value pairs in this [TieredSmt]. + pub fn iter(&self) -> impl Iterator { + self.values.iter() + } + /// Returns an iterator over all inner nodes of this [TieredSmt] (i.e., nodes not at depths 16 /// 32, 48, or 64). /// @@ -230,6 +235,7 @@ impl TieredSmt { self.nodes.upper_leaves().map(|(index, node)| { let key_prefix = index_to_prefix(index); let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found"); + debug_assert_eq!(key_to_index(key, index.depth()), *index); (*node, *key, *value) }) } diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index b80aebd..eca8e5d 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -108,6 +108,11 @@ impl ValueStore { None } + /// Returns an iterator over all key-value pairs in this store. + pub fn iter(&self) -> impl Iterator { + self.values.iter().flat_map(|(_, entry)| entry.iter()) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- From 6810b5e3ab77fe63e23a2e9a57e6ef92927f291d Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Tue, 1 Aug 2023 14:43:57 -0700 Subject: [PATCH 10/32] fix: node type check in inner_nodes() iterator of TSMT --- src/merkle/partial_mt/mod.rs | 3 +- src/merkle/tiered_smt/mod.rs | 189 +++++++++++++++++++------------- src/merkle/tiered_smt/nodes.rs | 94 +++++++++------- src/merkle/tiered_smt/tests.rs | 19 +++- src/merkle/tiered_smt/values.rs | 44 ++++---- 5 files changed, 208 insertions(+), 141 deletions(-) diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index ef87516..10f7231 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -118,7 +118,7 @@ impl PartialMerkleTree { // fill layers without nodes with empty vector for depth in 0..max_depth { - layers.entry(depth).or_insert(vec![]); + layers.entry(depth).or_default(); } let mut layer_iter = layers.into_values().rev(); @@ -370,7 +370,6 @@ impl PartialMerkleTree { return Ok(old_value); } - let mut node_index = node_index; let mut value = value.into(); for _ in 0..node_index.depth() { let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist"); diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index b8dd52f..5379d20 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -2,7 +2,7 @@ use super::{ BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, StarkField, Vec, Word, }; -use core::cmp; +use core::{cmp, ops::Deref}; mod nodes; use nodes::NodeStore; @@ -148,32 +148,36 @@ impl TieredSmt { return self.remove_leaf_node(key); } - // insert the value into the value store, and if nothing has changed, return - let (old_value, is_update) = match self.values.insert(key, value) { - Some(old_value) => { - if old_value == value { - return old_value; - } - (old_value, true) + // insert the value into the value store, and if the key was already in the store, update + // it with the new value + if let Some(old_value) = self.values.insert(key, value) { + if old_value != value { + // if the new value is different from the old 
value, determine the location of + // the leaf node for this key, build the node, and update the root + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + debug_assert!(leaf_exists); + let node = self.build_leaf_node(index, key, value); + self.root = self.nodes.update_leaf_node(index, node); } - None => (Self::EMPTY_VALUE, false), + return old_value; }; - // determine the index for the value node; this index could have 3 different meanings: - // - it points to a root of an empty subtree (excluding depth = 64); in this case, we can - // replace the node with the value node immediately. - // - it points to a node at the bottom tier (i.e., depth = 64); in this case, we need to - // process bottom-tier insertion which will be handled by insert_leaf_node(). - // - it points to an existing leaf node; this node could be a node with the same key or a - // different key with a common prefix; in the latter case, we'll need to move the leaf - // to a lower tier - let (index, leaf_exists) = self.nodes.get_insert_location(&key); - debug_assert!(!is_update || leaf_exists); + // determine the location for the leaf node; this index could have 3 different meanings: + // - it points to a root of an empty subtree or an empty node at depth 64; in this case, + // we can replace the node with the value node immediately. + // - it points to an existing leaf at the bottom tier (i.e., depth = 64); in this case, + // we need to process update the bottom leaf. + // - it points to an existing leaf node for a different key with the same prefix (same + // key case was handled above); in this case, we need to move the leaf to a lower tier + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + + self.root = if leaf_exists && index.depth() == Self::MAX_DEPTH { + // returned index points to a leaf at the bottom tier + let node = self.build_leaf_node(index, key, value); + self.nodes.update_leaf_node(index, node) + } else if leaf_exists { + // returned index pointes to a leaf for a different key with the same prefix - // if the returned index points to a leaf, and this leaf is for a different key (i.e., we - // are not updating a value for an existing key), we need to replace this leaf with a tree - // containing leaves for both the old and the new key-value pairs - if leaf_exists && !is_update { // get the key-value pair for the key with the same prefix; since the key-value // pair has already been inserted into the value store, we need to filter it out // when looking for the other key-value pair @@ -183,12 +187,12 @@ impl TieredSmt { .expect("other key-value pair not found"); // determine how far down the tree should we move the leaves - let common_prefix_len = get_common_prefix_tier(&key, other_key); + let common_prefix_len = get_common_prefix_tier_depth(&key, other_key); let depth = cmp::min(common_prefix_len + Self::TIER_SIZE, Self::MAX_DEPTH); // compute node locations for new and existing key-value paris - let new_index = key_to_index(&key, depth); - let other_index = key_to_index(other_key, depth); + let new_index = LeafNodeIndex::from_key(&key, depth); + let other_index = LeafNodeIndex::from_key(other_key, depth); // compute node values for the new and existing key-value pairs let new_node = self.build_leaf_node(new_index, key, value); @@ -196,19 +200,17 @@ impl TieredSmt { // replace the leaf located at index with a subtree containing nodes for new and // existing key-value paris - self.root = self.nodes.replace_leaf_with_subtree( + self.nodes.replace_leaf_with_subtree( index, [(new_index, 
new_node), (other_index, other_node)], - ); + ) } else { - // if the returned index points to an empty subtree, or a leaf with the same key (i.e., - // we are performing an update), or a leaf is at the bottom tier, compute its node - // value and do a simple insert + // returned index points to an empty subtree or an empty leaf at the bottom tier let node = self.build_leaf_node(index, key, value); - self.root = self.nodes.insert_leaf_node(index, node); - } + self.nodes.insert_leaf_node(index, node) + }; - old_value + Self::EMPTY_VALUE } // ITERATORS @@ -235,7 +237,7 @@ impl TieredSmt { self.nodes.upper_leaves().map(|(index, node)| { let key_prefix = index_to_prefix(index); let (key, value) = self.values.get_first(key_prefix).expect("upper leaf not found"); - debug_assert_eq!(key_to_index(key, index.depth()), *index); + debug_assert_eq!(*index, LeafNodeIndex::from_key(key, index.depth()).into()); (*node, *key, *value) }) } @@ -269,8 +271,8 @@ impl TieredSmt { }; // determine the location of the leaf holding the key-value pair to be removed - let (index, leaf_exists) = self.nodes.get_insert_location(&key); - debug_assert!(index.depth() == Self::MAX_DEPTH || leaf_exists); + let (index, leaf_exists) = self.nodes.get_leaf_index(&key); + debug_assert!(leaf_exists); // if the leaf is at the bottom tier and after removing the key-value pair from it, the // leaf is still not empty, just recompute its hash and update the leaf node. @@ -286,7 +288,7 @@ impl TieredSmt { // higher tier, we need to move the sibling to a higher tier if let Some((sib_key, sib_val, new_sib_index)) = self.values.get_lone_sibling(index) { // determine the current index of the sibling node - let sib_index = key_to_index(sib_key, index.depth()); + let sib_index = LeafNodeIndex::from_key(sib_key, index.depth()); debug_assert!(sib_index.depth() > new_sib_index.depth()); // compute node value for the new location of the sibling leaf and replace the subtree @@ -309,9 +311,8 @@ impl TieredSmt { /// the value store, however, for depths 16, 32, and 48, the node is computed directly from /// the passed-in values (for depth 64, the value store is queried to get all the key-value /// pairs located at the specified index). - fn build_leaf_node(&self, index: NodeIndex, key: RpoDigest, value: Word) -> RpoDigest { + fn build_leaf_node(&self, index: LeafNodeIndex, key: RpoDigest, value: Word) -> RpoDigest { let depth = index.depth(); - debug_assert!(Self::TIER_DEPTHS.contains(&depth)); // insert the key into index-key map and compute the new value of the node if index.depth() == Self::MAX_DEPTH { @@ -337,6 +338,71 @@ impl Default for TieredSmt { } } +// LEAF NODE INDEX +// ================================================================================================ +/// A wrapper around [NodeIndex] to provide type-safe references to nodes at depths 16, 32, 48, and +/// 64. +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct LeafNodeIndex(NodeIndex); + +impl LeafNodeIndex { + /// Returns a new [LeafNodeIndex] instantiated from the provided [NodeIndex]. + /// + /// In debug mode, panics if index depth is not 16, 32, 48, or 64. + pub fn new(index: NodeIndex) -> Self { + // check if the depth is 16, 32, 48, or 64; this works because for a valid depth, + // depth - 16, can be 0, 16, 32, or 48 - i.e., the value is either 0 or any of the 4th + // or 5th bits are set. We can test for this by computing a bitwise AND with a value + // which has all but the 4th and 5th bits set (which is !48). 
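+        // e.g. for depth 32: (32 - 16) & !48 == 16 & !48 == 0, while an invalid depth such as 17
+        // gives (17 - 16) & !48 == 1 and trips the assertion below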
+ debug_assert_eq!(((index.depth() - 16) & !48), 0, "invalid tier depth {}", index.depth()); + Self(index) + } + + /// Returns a new [LeafNodeIndex] instantiated from the specified key inserted at the specified + /// depth. + /// + /// The value for the key is computed by taking n most significant bits from the most significant + /// element of the key, where n is the specified depth. + pub fn from_key(key: &RpoDigest, depth: u8) -> Self { + let mse = get_key_prefix(key); + Self::new(NodeIndex::new_unchecked(depth, mse >> (TieredSmt::MAX_DEPTH - depth))) + } + + /// Returns a new [LeafNodeIndex] instantiated for testing purposes. + #[cfg(test)] + pub fn make(depth: u8, value: u64) -> Self { + Self::new(NodeIndex::make(depth, value)) + } + + /// Traverses towards the root until the specified depth is reached. + /// + /// The new depth must be a valid tier depth - i.e., 16, 32, 48, or 64. + pub fn move_up_to(&mut self, depth: u8) { + debug_assert_eq!(((depth - 16) & !48), 0, "invalid tier depth: {depth}"); + self.0.move_up_to(depth); + } +} + +impl Deref for LeafNodeIndex { + type Target = NodeIndex; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for LeafNodeIndex { + fn from(value: NodeIndex) -> Self { + Self::new(value) + } +} + +impl From for NodeIndex { + fn from(value: LeafNodeIndex) -> Self { + value.0 + } +} + // HELPER FUNCTIONS // ================================================================================================ @@ -351,19 +417,6 @@ fn index_to_prefix(index: &NodeIndex) -> u64 { index.value() << (TieredSmt::MAX_DEPTH - index.depth()) } -/// Returns index for the specified key inserted at the specified depth. -/// -/// The value for the key is computed by taking n most significant bits from the most significant -/// element of the key, where n is the specified depth. -fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { - let mse = get_key_prefix(key); - let value = match depth { - 16 | 32 | 48 | 64 => mse >> ((TieredSmt::MAX_DEPTH - depth) as u32), - _ => unreachable!("invalid depth: {depth}"), - }; - NodeIndex::new_unchecked(depth, value) -} - /// Returns tiered common prefix length between the most significant elements of the provided keys. /// /// Specifically: @@ -372,36 +425,13 @@ fn key_to_index(key: &RpoDigest, depth: u8) -> NodeIndex { /// - returns 32 if the common prefix is between 32 and 47 bits. /// - returns 16 if the common prefix is between 16 and 31 bits. /// - returns 0 if the common prefix is fewer than 16 bits. -fn get_common_prefix_tier(key1: &RpoDigest, key2: &RpoDigest) -> u8 { +fn get_common_prefix_tier_depth(key1: &RpoDigest, key2: &RpoDigest) -> u8 { let e1 = get_key_prefix(key1); let e2 = get_key_prefix(key2); let ex = (e1 ^ e2).leading_zeros() as u8; (ex / 16) * 16 } -/// Returns a tier for the specified index. -/// -/// The tiers are defined as follows: -/// - Tier 0: depth 0 through 16 (inclusive). -/// - Tier 1: depth 17 through 32 (inclusive). -/// - Tier 2: depth 33 through 48 (inclusive). -/// - Tier 3: depth 49 through 64 (inclusive). -const fn get_index_tier(index: &NodeIndex) -> usize { - debug_assert!(index.depth() <= TieredSmt::MAX_DEPTH, "invalid depth"); - match index.depth() { - 0..=16 => 0, - 17..=32 => 1, - 33..=48 => 2, - _ => 3, - } -} - -/// Returns true if the specified index is an index for an leaf node (i.e., the depth is 16, 32, -/// 48, or 64). 
-const fn is_leaf_node(index: &NodeIndex) -> bool { - matches!(index.depth(), 16 | 32 | 48 | 64) -} - /// Computes node value for leaves at tiers 16, 32, or 48. /// /// Node value is computed as: hash(key || value, domain = depth). @@ -413,7 +443,10 @@ pub fn hash_upper_leaf(key: RpoDigest, value: Word, depth: u8) -> RpoDigest { /// Computes node value for leaves at the bottom tier (depth 64). /// -/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n, domain=64]). +/// Node value is computed as: hash([key_0, value_0, ..., key_n, value_n], domain=64). +/// +/// TODO: when hashing in domain is implemented for `hash_elements()`, combine this function with +/// `hash_upper_leaf()` function. pub fn hash_bottom_leaf(values: &[(RpoDigest, Word)]) -> RpoDigest { let mut elements = Vec::with_capacity(values.len() * 8); for (key, val) in values.iter() { diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 42bad5e..7135c6c 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -1,6 +1,6 @@ use super::{ - get_index_tier, get_key_prefix, is_leaf_node, BTreeMap, BTreeSet, EmptySubtreeRoots, - InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, + BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, LeafNodeIndex, MerkleError, MerklePath, + NodeIndex, Rpo256, RpoDigest, Vec, }; // CONSTANTS @@ -21,7 +21,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// A store of nodes for a Tiered Sparse Merkle tree. /// /// The store contains information about all nodes as well as information about which of the nodes -/// represent leaf nodes in a Tiered Sparse Merkle tree. +/// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s +/// are used to determine the position of the leaves in the tree. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NodeStore { nodes: BTreeMap, @@ -88,14 +89,13 @@ impl NodeStore { /// Returns an index at which a leaf node for the specified key should be inserted. /// /// The second value in the returned tuple is set to true if the node at the returned index - /// is already a leaf node, excluding leaves at the bottom tier (i.e., if the leaf is at the - /// bottom tier, false is returned). - pub fn get_insert_location(&self, key: &RpoDigest) -> (NodeIndex, bool) { + /// is already a leaf node. + pub fn get_leaf_index(&self, key: &RpoDigest) -> (LeafNodeIndex, bool) { // traverse the tree from the root down checking nodes at tiers 16, 32, and 48. Return if // a node at any of the tiers is either a leaf or a root of an empty subtree. - let mse = get_key_prefix(key); - for depth in (TIER_DEPTHS[0]..MAX_DEPTH).step_by(TIER_SIZE as usize) { - let index = NodeIndex::new_unchecked(depth, mse >> (MAX_DEPTH - depth)); + const NUM_UPPER_TIERS: usize = TIER_DEPTHS.len() - 1; + for &tier_depth in TIER_DEPTHS[..NUM_UPPER_TIERS].iter() { + let index = LeafNodeIndex::from_key(key, tier_depth); if self.upper_leaves.contains(&index) { return (index, true); } else if !self.nodes.contains_key(&index) { @@ -105,8 +105,8 @@ impl NodeStore { // if we got here, that means all of the nodes checked so far are internal nodes, and // the new node would need to be inserted in the bottom tier. - let index = NodeIndex::new_unchecked(MAX_DEPTH, mse); - (index, false) + let index = LeafNodeIndex::from_key(key, MAX_DEPTH); + (index, self.bottom_leaves.contains(&index.value())) } // ITERATORS @@ -118,7 +118,7 @@ impl NodeStore { /// The iterator order is unspecified. 
pub fn inner_nodes(&self) -> impl Iterator + '_ { self.nodes.iter().filter_map(|(index, node)| { - if !is_leaf_node(index) { + if self.is_internal_node(index) { Some(InnerNodeInfo { value: *node, left: self.get_node_unchecked(&index.left_child()), @@ -152,20 +152,26 @@ impl NodeStore { /// at the specified indexes. Recomputes and returns the new root. pub fn replace_leaf_with_subtree( &mut self, - leaf_index: NodeIndex, - subtree_leaves: [(NodeIndex, RpoDigest); 2], + leaf_index: LeafNodeIndex, + subtree_leaves: [(LeafNodeIndex, RpoDigest); 2], ) -> RpoDigest { - debug_assert!(is_leaf_node(&leaf_index)); - debug_assert!(is_leaf_node(&subtree_leaves[0].0)); - debug_assert!(is_leaf_node(&subtree_leaves[1].0)); + debug_assert!(self.is_non_empty_leaf(&leaf_index)); debug_assert!(!is_empty_root(&subtree_leaves[0].1)); debug_assert!(!is_empty_root(&subtree_leaves[1].1)); debug_assert_eq!(subtree_leaves[0].0.depth(), subtree_leaves[1].0.depth()); debug_assert!(leaf_index.depth() < subtree_leaves[0].0.depth()); self.upper_leaves.remove(&leaf_index); - self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1); - self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1) + + if subtree_leaves[0].0 == subtree_leaves[1].0 { + // if the subtree is for a single node at depth 64, we only need to insert one node + debug_assert_eq!(subtree_leaves[0].0.depth(), MAX_DEPTH); + debug_assert_eq!(subtree_leaves[0].1, subtree_leaves[1].1); + self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1) + } else { + self.insert_leaf_node(subtree_leaves[0].0, subtree_leaves[0].1); + self.insert_leaf_node(subtree_leaves[1].0, subtree_leaves[1].1) + } } /// Replaces a subtree containing the retained and the removed leaf nodes, with a leaf node @@ -175,14 +181,14 @@ impl NodeStore { /// moving the node at the `retained_leaf` index up to the tier specified by `new_depth`. pub fn replace_subtree_with_leaf( &mut self, - removed_leaf: NodeIndex, - retained_leaf: NodeIndex, + removed_leaf: LeafNodeIndex, + retained_leaf: LeafNodeIndex, new_depth: u8, node: RpoDigest, ) -> RpoDigest { debug_assert!(!is_empty_root(&node)); - debug_assert!(self.is_leaf(&removed_leaf)); - debug_assert!(self.is_leaf(&retained_leaf)); + debug_assert!(self.is_non_empty_leaf(&removed_leaf)); + debug_assert!(self.is_non_empty_leaf(&retained_leaf)); debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth()); debug_assert!(removed_leaf.depth() > new_depth); @@ -202,7 +208,6 @@ impl NodeStore { // compute the index of the common root for retained and removed leaves let mut new_index = retained_leaf; new_index.move_up_to(new_depth); - debug_assert!(is_leaf_node(&new_index)); // insert the node at the root index self.insert_leaf_node(new_index, node) @@ -211,19 +216,21 @@ impl NodeStore { /// Inserts the specified node at the specified index; recomputes and returns the new root /// of the Tiered Sparse Merkle tree. /// - /// This method assumes that node is a non-empty value. - pub fn insert_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest { - debug_assert!(is_leaf_node(&index)); + /// This method assumes that the provided node is a non-empty value, and that there is no node + /// at the specified index. 
+ pub fn insert_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest { debug_assert!(!is_empty_root(&node)); + debug_assert_eq!(self.nodes.get(&index), None); // mark the node as the leaf if index.depth() == MAX_DEPTH { self.bottom_leaves.insert(index.value()); } else { - self.upper_leaves.insert(index); + self.upper_leaves.insert(index.into()); }; // insert the node and update the path from the node to the root + let mut index: NodeIndex = index.into(); for _ in 0..index.depth() { self.nodes.insert(index, node); let sibling = self.get_node_unchecked(&index.sibling()); @@ -240,8 +247,8 @@ impl NodeStore { /// returns the new root of the Tiered Sparse Merkle tree. /// /// This method can accept `node` as either an empty or a non-empty value. - pub fn update_leaf_node(&mut self, mut index: NodeIndex, mut node: RpoDigest) -> RpoDigest { - debug_assert!(self.is_leaf(&index)); + pub fn update_leaf_node(&mut self, index: LeafNodeIndex, mut node: RpoDigest) -> RpoDigest { + debug_assert!(self.is_non_empty_leaf(&index)); // if the value we are updating the node to is a root of an empty tree, clear the leaf // flag for this node @@ -256,6 +263,7 @@ impl NodeStore { } // update the path from the node to the root + let mut index: NodeIndex = index.into(); for _ in 0..index.depth() { if node == EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize] { self.nodes.remove(&index); @@ -275,8 +283,8 @@ impl NodeStore { /// Replaces the leaf node at the specified index with a root of an empty subtree; recomputes /// and returns the new root of the Tiered Sparse Merkle tree. - pub fn clear_leaf_node(&mut self, index: NodeIndex) -> RpoDigest { - debug_assert!(self.is_leaf(&index)); + pub fn clear_leaf_node(&mut self, index: LeafNodeIndex) -> RpoDigest { + debug_assert!(self.is_non_empty_leaf(&index)); let node = EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[index.depth() as usize]; self.update_leaf_node(index, node) } @@ -285,8 +293,7 @@ impl NodeStore { // -------------------------------------------------------------------------------------------- /// Returns true if the node at the specified index is a leaf node. - fn is_leaf(&self, index: &NodeIndex) -> bool { - debug_assert!(is_leaf_node(index)); + fn is_non_empty_leaf(&self, index: &LeafNodeIndex) -> bool { if index.depth() == MAX_DEPTH { self.bottom_leaves.contains(&index.value()) } else { @@ -294,6 +301,16 @@ impl NodeStore { } } + /// Returns true if the node at the specified index is an internal node - i.e., there is + /// no leaf at that node and the node does not belong to the bottom tier. + fn is_internal_node(&self, index: &NodeIndex) -> bool { + if index.depth() == MAX_DEPTH { + false + } else { + !self.upper_leaves.contains(index) + } + } + /// Checks if the specified index is valid in the context of this Merkle tree. /// /// # Errors @@ -309,7 +326,7 @@ impl NodeStore { } else { // make sure that there are no leaf nodes in the ancestors of the index; since leaf // nodes can live at specific depth, we just need to check these depths. - let tier = get_index_tier(&index); + let tier = ((index.depth() - 1) / TIER_SIZE) as usize; let mut tier_index = index; for &depth in TIER_DEPTHS[..tier].iter().rev() { tier_index.move_up_to(depth); @@ -335,12 +352,13 @@ impl NodeStore { } /// Removes a sequence of nodes starting at the specified index and traversing the - /// tree up to the specified depth. + /// tree up to the specified depth. The node at the `end_depth` is also removed. 
/// /// This method does not update any other nodes and does not recompute the tree root. - fn remove_branch(&mut self, mut index: NodeIndex, end_depth: u8) { + fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) { + let mut index: NodeIndex = index.into(); assert!(index.depth() > end_depth); - for _ in 0..(index.depth() - end_depth) { + for _ in 0..(index.depth() - end_depth + 1) { self.nodes.remove(&index); index.move_up() } diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index 845e76e..e459c90 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -509,9 +509,26 @@ fn tsmt_bottom_tier() { actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); // make sure leaves are returned correctly - let mut leaves = smt.bottom_leaves(); + let smt_clone = smt.clone(); + let mut leaves = smt_clone.bottom_leaves(); assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a)]))); assert_eq!(leaves.next(), None); + + // --- update a leaf at the bottom tier ------------------------------------------------------- + + let val_a2 = [Felt::new(3); WORD_SIZE]; + assert_eq!(smt.insert(key_a, val_a2), val_a); + + let leaf_node = build_bottom_leaf_node(&[key_b, key_a], &[val_b, val_a2]); + store.set_node(tree_root, index, leaf_node).unwrap(); + + let expected_nodes = get_non_empty_nodes(&store); + let actual_nodes = smt.inner_nodes().collect::>(); + actual_nodes.iter().for_each(|node| assert!(expected_nodes.contains(node))); + + let mut leaves = smt.bottom_leaves(); + assert_eq!(leaves.next(), Some((leaf_node, vec![(key_b, val_b), (key_a, val_a2)]))); + assert_eq!(leaves.next(), None); } #[test] diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index eca8e5d..ec2a465 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -1,4 +1,4 @@ -use super::{get_key_prefix, is_leaf_node, BTreeMap, NodeIndex, RpoDigest, StarkField, Vec, Word}; +use super::{get_key_prefix, BTreeMap, LeafNodeIndex, RpoDigest, StarkField, Vec, Word}; use crate::utils::vec; use core::{ cmp::{Ord, Ordering}, @@ -23,7 +23,8 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// the values are the corresponding key-value pairs (or a list of key-value pairs if more that /// a single key-value pair shares the same 64-bit prefix). /// -/// The store supports lookup by the full key as well as by the 64-bit key prefix. +/// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key +/// prefix. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ValueStore { values: BTreeMap, @@ -76,26 +77,29 @@ impl ValueStore { /// /// This method assumes that the key-value pair for the specified index has already been /// removed from the store. - pub fn get_lone_sibling(&self, index: NodeIndex) -> Option<(&RpoDigest, &Word, NodeIndex)> { - debug_assert!(is_leaf_node(&index)); - + pub fn get_lone_sibling( + &self, + index: LeafNodeIndex, + ) -> Option<(&RpoDigest, &Word, LeafNodeIndex)> { // iterate over tiers from top to bottom, looking at the tiers which are strictly above // the depth of the index. This implies that only tiers at depth 32 and 48 will be // considered. For each tier, check if the parent of the index at the higher tier - // contains a single node. - for &tier in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) { + // contains a single node. The fist tier (depth 16) is excluded because we cannot move + // nodes at depth 16 to a higher tier. 
This implies that nodes at the first tier will + // never have "lone siblings". + for &tier_depth in TIER_DEPTHS.iter().filter(|&t| index.depth() > *t) { // compute the index of the root at a higher tier let mut parent_index = index; - parent_index.move_up_to(tier); + parent_index.move_up_to(tier_depth); // find the lone sibling, if any; we need to handle the "last node" at a given tier // separately specify the bounds for the search correctly. - let start_prefix = parent_index.value() << (MAX_DEPTH - tier); - let sibling = if start_prefix.leading_ones() as u8 == tier { + let start_prefix = parent_index.value() << (MAX_DEPTH - tier_depth); + let sibling = if start_prefix.leading_ones() as u8 == tier_depth { let mut iter = self.range(start_prefix..); iter.next().filter(|_| iter.next().is_none()) } else { - let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier); + let end_prefix = (parent_index.value() + 1) << (MAX_DEPTH - tier_depth); let mut iter = self.range(start_prefix..end_prefix); iter.next().filter(|_| iter.next().is_none()) }; @@ -346,12 +350,8 @@ fn cmp_digests(d1: &RpoDigest, d2: &RpoDigest) -> Ordering { #[cfg(test)] mod tests { - - use super::{RpoDigest, ValueStore}; - use crate::{ - merkle::{tiered_smt::values::StoreEntry, NodeIndex}, - Felt, ONE, WORD_SIZE, ZERO, - }; + use super::{LeafNodeIndex, RpoDigest, StoreEntry, ValueStore}; + use crate::{Felt, ONE, WORD_SIZE, ZERO}; #[test] fn test_insert() { @@ -569,17 +569,17 @@ mod tests { store.insert(key_b, value_b); // check sibling node for `a` - let index = NodeIndex::make(32, 0b_10101010_10101010_00011111_11111110); - let parent_index = NodeIndex::make(16, 0b_10101010_10101010); + let index = LeafNodeIndex::make(32, 0b_10101010_10101010_00011111_11111110); + let parent_index = LeafNodeIndex::make(16, 0b_10101010_10101010); assert_eq!(store.get_lone_sibling(index), Some((&key_a, &value_a, parent_index))); // check sibling node for `b` - let index = NodeIndex::make(32, 0b_11111111_11111111_00011111_11111111); - let parent_index = NodeIndex::make(16, 0b_11111111_11111111); + let index = LeafNodeIndex::make(32, 0b_11111111_11111111_00011111_11111111); + let parent_index = LeafNodeIndex::make(16, 0b_11111111_11111111); assert_eq!(store.get_lone_sibling(index), Some((&key_b, &value_b, parent_index))); // check some other sibling for some other index - let index = NodeIndex::make(32, 0b_11101010_10101010); + let index = LeafNodeIndex::make(32, 0b_11101010_10101010); assert_eq!(store.get_lone_sibling(index), None); } } From 1ac30f898936bf172d245b24f79eec4be218fcc4 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Wed, 2 Aug 2023 03:10:31 -0700 Subject: [PATCH 11/32] feat: implement ability to generate TSMT proofs --- src/merkle/mod.rs | 2 +- src/merkle/tiered_smt/mod.rs | 28 ++++++ src/merkle/tiered_smt/nodes.rs | 9 ++ src/merkle/tiered_smt/proof.rs | 134 +++++++++++++++++++++++++++++ src/merkle/tiered_smt/tests.rs | 150 ++++++++++++++++++++++++++++++++- 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/merkle/tiered_smt/proof.rs diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index c49c004..8e9979a 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -30,7 +30,7 @@ mod simple_smt; pub use simple_smt::SimpleSmt; mod tiered_smt; -pub use tiered_smt::TieredSmt; +pub use tiered_smt::{TieredSmt, TieredSmtProof}; mod mmr; pub use mmr::{Mmr, MmrPeaks, MmrProof}; diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 5379d20..c4f30ed 100644 --- 
a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -2,6 +2,7 @@ use super::{ BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, StarkField, Vec, Word, }; +use crate::utils::vec; use core::{cmp, ops::Deref}; mod nodes; @@ -10,6 +11,9 @@ use nodes::NodeStore; mod values; use values::ValueStore; +mod proof; +pub use proof::TieredSmtProof; + #[cfg(test)] mod tests; @@ -134,6 +138,30 @@ impl TieredSmt { } } + /// Returns a proof for a key-value pair defined by the specified key. + /// + /// The proof can be used to attest membership of this key-value pair in a Tiered Sparse Merkle + /// Tree defined by the same root as this tree. + pub fn prove(&self, key: RpoDigest) -> TieredSmtProof { + let (path, index, leaf_exists) = self.nodes.get_proof(&key); + + let entries = if index.depth() == Self::MAX_DEPTH { + match self.values.get_all(index.value()) { + Some(entries) => entries, + None => vec![(key, Self::EMPTY_VALUE)], + } + } else if leaf_exists { + let entry = + self.values.get_first(index_to_prefix(&index)).expect("leaf entry not found"); + debug_assert_eq!(entry.0, key); + vec![*entry] + } else { + vec![(key, Self::EMPTY_VALUE)] + }; + + TieredSmtProof::new(path, entries) + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 7135c6c..9db4ba3 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -86,6 +86,15 @@ impl NodeStore { Ok(path.into()) } + /// Returns a Merkle path to the node specified by the key together with a flag indicating, + /// whether this node is a leaf at depths 16, 32, or 48. + pub fn get_proof(&self, key: &RpoDigest) -> (MerklePath, NodeIndex, bool) { + let (index, leaf_exists) = self.get_leaf_index(key); + let index: NodeIndex = index.into(); + let path = self.get_path(index).expect("failed to retrieve Merkle path for a node index"); + (path, index, leaf_exists) + } + /// Returns an index at which a leaf node for the specified key should be inserted. /// /// The second value in the returned tuple is set to true if the node at the returned index diff --git a/src/merkle/tiered_smt/proof.rs b/src/merkle/tiered_smt/proof.rs new file mode 100644 index 0000000..3965faa --- /dev/null +++ b/src/merkle/tiered_smt/proof.rs @@ -0,0 +1,134 @@ +use super::{ + get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf, + EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, Vec, Word, +}; + +// CONSTANTS +// ================================================================================================ + +/// Maximum node depth. This is also the bottom tier of the tree. +const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; + +/// Value of an empty leaf. +pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE; + +// TIERED SPARSE MERKLE TREE PROOF +// ================================================================================================ + +/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a +/// Tiered Sparse Merkle tree. +/// +/// The proof consists of a Merkle path and one or more key-value entries which describe the node +/// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4], +/// the entries will contain a single item with value set to [ZERO; 4]. 
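For orientation, a minimal usage sketch of the proving API introduced in this patch; it mirrors the `tsmt_get_proof` test added below, and the import paths and key material are illustrative assumptions rather than part of the change:

use miden_crypto::{hash::rpo::RpoDigest, merkle::TieredSmt, Felt, Word, ONE};

fn prove_and_verify() {
    // build a tree with a single key-value pair
    let mut smt = TieredSmt::default();
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(12345)]);
    let value: Word = [ONE, ONE, ONE, ONE];
    smt.insert(key, value);

    // generate a proof for the key and check it against the current root
    let proof = smt.prove(key);
    assert!(proof.verify_membership(&key, &value, &smt.root()));
    assert_eq!(proof.get(&key), Some(value));
}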
+pub struct TieredSmtProof { + path: MerklePath, + entries: Vec<(RpoDigest, Word)>, +} + +impl TieredSmtProof { + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Returns a new instance of [TieredSmtProof] instantiated from the specified path and entries. + /// + /// # Panics + /// Panics if: + /// - The length of the path is greater than 64. + /// - Entries is an empty vector. + /// - Entries contains more than 1 item, but the length of the path is not 64. + /// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4]. + /// - Entries contains multiple items with keys which don't share the same 64-bit prefix. + pub fn new(path: MerklePath, entries: Vec<(RpoDigest, Word)>) -> Self { + assert!(path.depth() <= MAX_DEPTH); + assert!(!entries.is_empty()); + if entries.len() > 1 { + assert!(path.depth() == MAX_DEPTH); + let prefix = get_key_prefix(&entries[0].0); + for entry in entries.iter().skip(1) { + assert_ne!(entry.1, EMPTY_VALUE); + assert_eq!(prefix, get_key_prefix(&entry.0)); + } + } + + Self { path, entries } + } + + // PROOF VERIFIER + // -------------------------------------------------------------------------------------------- + + /// Returns true if a Tiered Sparse Merkle tree with the specified root contains the provided + /// key-value pair. + /// + /// Note: this method cannot be used to assert non-membership. That is, if false is returned, + /// it does not mean that the provided key-value pair is not in the tree. + pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool { + if self.is_value_empty() { + if value != &EMPTY_VALUE { + return false; + } + // if the proof is for an empty value, we can verify it against any key which has a + // common prefix with the key stored in entries, but the common prefix must be at least + // as long as the path + let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0); + if common_prefix_tier < self.path.depth() { + return false; + } + } else if !self.entries.contains(&(*key, *value)) { + return false; + } + + // make sure the Merkle path resolves to the correct root + root == &self.compute_root() + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the value associated with the specified key according to this proof, or None if + /// this proof does not contain a value for the specified key. + /// + /// A key-value pair generated by using this method should pass the `verify_membership()` check. + pub fn get(&self, key: &RpoDigest) -> Option<Word> { + if self.is_value_empty() { + let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0); + if common_prefix_tier < self.path.depth() { + None + } else { + Some(EMPTY_VALUE) + } + } else { + self.entries.iter().find(|(k, _)| k == key).map(|(_, value)| *value) + } + } + + /// Computes the root of a Tiered Sparse Merkle tree to which this proof resolves. + pub fn compute_root(&self) -> RpoDigest { + let node = self.build_node(); + let index = LeafNodeIndex::from_key(&self.entries[0].0, self.path.depth()); + self.path + .compute_root(index.value(), node) + .expect("failed to compute Merkle path root") + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Returns true if the proof is for an empty value.
+ fn is_value_empty(&self) -> bool { + self.entries[0].1 == EMPTY_VALUE + } + + /// Converts the entries contained in this proof into a node value for node at the base of the + /// path contained in this proof. + fn build_node(&self) -> RpoDigest { + let depth = self.path.depth(); + if self.is_value_empty() { + EmptySubtreeRoots::empty_hashes(MAX_DEPTH)[depth as usize] + } else if depth == MAX_DEPTH { + hash_bottom_leaf(&self.entries) + } else { + let (key, value) = self.entries[0]; + hash_upper_leaf(key, value, depth) + } + } +} diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index e459c90..c1b7649 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -1,5 +1,5 @@ use super::{ - super::{super::ONE, Felt, MerkleStore, WORD_SIZE, ZERO}, + super::{super::ONE, empty_roots::EMPTY_WORD, Felt, MerkleStore, WORD_SIZE, ZERO}, EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word, }; @@ -587,6 +587,154 @@ fn tsmt_bottom_tier_two() { assert_eq!(leaves.next(), None); } +// GET PROOF TESTS +// ================================================================================================ + +#[test] +fn tsmt_get_proof() { + let mut smt = TieredSmt::default(); + + // --- insert a value into the tree --------------------------------------- + let raw_a = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // --- insert a value with the same 48-bit prefix into the tree ----------- + let raw_b = 0b_01010101_01010101_11111111_11111111_10110101_10101010_10111100_00000000_u64; + let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_b)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + let smt_alt = smt.clone(); + + // --- insert a value with the same 32-bit prefix into the tree ----------- + let raw_c = 0b_01010101_01010101_11111111_11111111_11111101_10101010_10111100_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + + // --- insert a value with the same 64-bit prefix as A into the tree ------ + let raw_d = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000000_u64; + let key_d = RpoDigest::from([ZERO, ZERO, ONE, Felt::new(raw_d)]); + let value_d = [ONE, ZERO, ZERO, ZERO]; + smt.insert(key_d, value_d); + + // at this point the tree looks as follows: + // - A and D are located in the same node at depth 64. + // - B is located at depth 64 and shares the same 48-bit prefix with A and D. + // - C is located at depth 48 and shares the same 32-bit prefix with A, B, and D. 
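A note on why A and D share a node: the 64-bit prefix that drives tier placement is derived from the last element of the key, so two keys that agree on that element (here `Felt::new(raw_a)` and `Felt::new(raw_d)` with `raw_a == raw_d`) map to the same depth-64 leaf no matter how their other elements differ. In the test's own terms (illustrative only):

assert_eq!(raw_a, raw_d);
assert_eq!(key_a[3].as_int(), key_d[3].as_int()); // identical 64-bit prefixes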
+ + // --- generate proof for key A and test that it verifies correctly ------- + let proof = smt.prove(key_a); + assert!(proof.verify_membership(&key_a, &value_a, &smt.root())); + + assert!(!proof.verify_membership(&key_a, &value_b, &smt.root())); + assert!(!proof.verify_membership(&key_a, &EMPTY_WORD, &smt.root())); + assert!(!proof.verify_membership(&key_b, &value_a, &smt.root())); + assert!(!proof.verify_membership(&key_a, &value_a, &smt_alt.root())); + + assert_eq!(proof.get(&key_a), Some(value_a)); + assert_eq!(proof.get(&key_b), None); + + // since A and D are stored in the same node, we should be able to use the proof to verify + // membership of D + assert!(proof.verify_membership(&key_d, &value_d, &smt.root())); + assert_eq!(proof.get(&key_d), Some(value_d)); + + // --- generate proof for key B and test that it verifies correctly ------- + let proof = smt.prove(key_b); + assert!(proof.verify_membership(&key_b, &value_b, &smt.root())); + + assert!(!proof.verify_membership(&key_b, &value_a, &smt.root())); + assert!(!proof.verify_membership(&key_b, &EMPTY_WORD, &smt.root())); + assert!(!proof.verify_membership(&key_a, &value_b, &smt.root())); + assert!(!proof.verify_membership(&key_b, &value_b, &smt_alt.root())); + + assert_eq!(proof.get(&key_b), Some(value_b)); + assert_eq!(proof.get(&key_a), None); + + // --- generate proof for key C and test that it verifies correctly ------- + let proof = smt.prove(key_c); + assert!(proof.verify_membership(&key_c, &value_c, &smt.root())); + + assert!(!proof.verify_membership(&key_c, &value_a, &smt.root())); + assert!(!proof.verify_membership(&key_c, &EMPTY_WORD, &smt.root())); + assert!(!proof.verify_membership(&key_a, &value_c, &smt.root())); + assert!(!proof.verify_membership(&key_c, &value_c, &smt_alt.root())); + + assert_eq!(proof.get(&key_c), Some(value_c)); + assert_eq!(proof.get(&key_b), None); + + // --- generate proof for key D and test that it verifies correctly ------- + let proof = smt.prove(key_d); + assert!(proof.verify_membership(&key_d, &value_d, &smt.root())); + + assert!(!proof.verify_membership(&key_d, &value_b, &smt.root())); + assert!(!proof.verify_membership(&key_d, &EMPTY_WORD, &smt.root())); + assert!(!proof.verify_membership(&key_b, &value_d, &smt.root())); + assert!(!proof.verify_membership(&key_d, &value_d, &smt_alt.root())); + + assert_eq!(proof.get(&key_d), Some(value_d)); + assert_eq!(proof.get(&key_b), None); + + // since A and D are stored in the same node, we should be able to use the proof to verify + // membership of A + assert!(proof.verify_membership(&key_a, &value_a, &smt.root())); + assert_eq!(proof.get(&key_a), Some(value_a)); + + // --- generate proof for an empty key at depth 64 ------------------------ + // this key has the same 48-bit prefix as A but is different from B + let raw = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000011_u64; + let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + + let proof = smt.prove(key); + assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root())); + + assert!(!proof.verify_membership(&key, &value_a, &smt.root())); + assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root())); + + assert_eq!(proof.get(&key), Some(EMPTY_WORD)); + assert_eq!(proof.get(&key_b), None); + + // the same proof should verify against any key with the same 64-bit prefix + let key2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]); + assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root())); + assert_eq!(proof.get(&key2), 
Some(EMPTY_WORD)); + + // but verifying if against a key with the same 63-bit prefix (or smaller) should fail + let raw3 = 0b_01010101_01010101_11111111_11111111_10110101_10101010_11111100_00000010_u64; + let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]); + assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root())); + assert_eq!(proof.get(&key3), None); + + // --- generate proof for an empty key at depth 48 ------------------------ + // this key has the same 32-prefix as A, B, C, and D, but is different from C + let raw = 0b_01010101_01010101_11111111_11111111_00110101_10101010_11111100_00000000_u64; + let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + + let proof = smt.prove(key); + assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root())); + + assert!(!proof.verify_membership(&key, &value_a, &smt.root())); + assert!(!proof.verify_membership(&key, &EMPTY_WORD, &smt_alt.root())); + + assert_eq!(proof.get(&key), Some(EMPTY_WORD)); + assert_eq!(proof.get(&key_b), None); + + // the same proof should verify against any key with the same 48-bit prefix + let raw2 = 0b_01010101_01010101_11111111_11111111_00110101_10101010_01111100_00000000_u64; + let key2 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw2)]); + assert!(proof.verify_membership(&key2, &EMPTY_WORD, &smt.root())); + assert_eq!(proof.get(&key2), Some(EMPTY_WORD)); + + // but verifying against a key with the same 47-bit prefix (or smaller) should fail + let raw3 = 0b_01010101_01010101_11111111_11111111_00110101_10101011_11111100_00000000_u64; + let key3 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw3)]); + assert!(!proof.verify_membership(&key3, &EMPTY_WORD, &smt.root())); + assert_eq!(proof.get(&key3), None); +} + // ERROR TESTS // ================================================================================================ From 33ef78f8f5b2194206c1c26ac8cef941294a4832 Mon Sep 17 00:00:00 2001 From: "Augusto F. Hack" Date: Thu, 3 Aug 2023 14:58:51 +0200 Subject: [PATCH 12/32] tsmt: add basic traits and into/from parts methods --- src/merkle/tiered_smt/proof.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/merkle/tiered_smt/proof.rs b/src/merkle/tiered_smt/proof.rs index 3965faa..0e7fcbf 100644 --- a/src/merkle/tiered_smt/proof.rs +++ b/src/merkle/tiered_smt/proof.rs @@ -21,6 +21,7 @@ pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE; /// The proof consists of a Merkle path and one or more key-value entries which describe the node /// located at the base of the path. If the node at the base of the path resolves to [ZERO; 4], /// the entries will contain a single item with value set to [ZERO; 4]. +#[derive(PartialEq, Eq, Debug)] pub struct TieredSmtProof { path: MerklePath, entries: Vec<(RpoDigest, Word)>, @@ -38,7 +39,11 @@ impl TieredSmtProof { /// - Entries contains more than 1 item, but the length of the path is not 64. /// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4]. /// - Entries contains multiple items with keys which don't share the same 64-bit prefix. - pub fn new(path: MerklePath, entries: Vec<(RpoDigest, Word)>) -> Self { + pub fn new(path: MerklePath, entries: I) -> Self + where + I: IntoIterator, + { + let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect(); assert!(path.depth() <= MAX_DEPTH); assert!(!entries.is_empty()); if entries.len() > 1 { @@ -110,6 +115,11 @@ impl TieredSmtProof { .expect("failed to compute Merkle path root") } + /// Consume the proof and returns its parts. 
+ pub fn into_parts(self) -> (MerklePath, Vec<(RpoDigest, Word)>) { + (self.path, self.entries) + } + // HELPER METHODS // -------------------------------------------------------------------------------------------- From 83b69464328d710a5720523d16b931d073f6bd5b Mon Sep 17 00:00:00 2001 From: "Augusto F. Hack" Date: Thu, 3 Aug 2023 18:57:19 +0200 Subject: [PATCH 13/32] tsmt: return error code instead of panic --- src/merkle/tiered_smt/error.rs | 49 ++++++++++++++++++++++++++++++++++ src/merkle/tiered_smt/mod.rs | 5 +++- src/merkle/tiered_smt/proof.rs | 34 +++++++++++++++++------ 3 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 src/merkle/tiered_smt/error.rs diff --git a/src/merkle/tiered_smt/error.rs b/src/merkle/tiered_smt/error.rs new file mode 100644 index 0000000..fdc6123 --- /dev/null +++ b/src/merkle/tiered_smt/error.rs @@ -0,0 +1,49 @@ +use core::fmt::Display; + +#[derive(Debug, PartialEq, Eq)] +pub enum TieredSmtProofError { + EntriesEmpty, + PathTooLong, + NotATierPath(u8), + MultipleEntriesOutsideLastTier, + EmptyValueNotAllowed, + UnmatchingPrefixes(u64, u64), +} + +impl Display for TieredSmtProofError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TieredSmtProofError::EntriesEmpty => { + write!(f, "Missing entries for tiered sparse merkle tree proof") + } + TieredSmtProofError::PathTooLong => { + write!( + f, + "Path longer than maximum depth of 64 for tiered sparse merkle tree proof" + ) + } + TieredSmtProofError::NotATierPath(got) => { + write!( + f, + "Path length does not correspond to a tier. Got {} Expected one of 16,32,48,64", + got + ) + } + TieredSmtProofError::MultipleEntriesOutsideLastTier => { + write!(f, "Multiple entries are only allowed for the last tier (depth 64)") + } + TieredSmtProofError::EmptyValueNotAllowed => { + write!( + f, + "The empty value [0,0,0,0] is not allowed inside a tiered sparse merkle tree" + ) + } + TieredSmtProofError::UnmatchingPrefixes(first, second) => { + write!(f, "Not all leaves have the same prefix. First {} second {}", first, second) + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TieredSmtProofError {} diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index c4f30ed..a0bd723 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -14,6 +14,9 @@ use values::ValueStore; mod proof; pub use proof::TieredSmtProof; +mod error; +pub use error::TieredSmtProofError; + #[cfg(test)] mod tests; @@ -159,7 +162,7 @@ impl TieredSmt { vec![(key, Self::EMPTY_VALUE)] }; - TieredSmtProof::new(path, entries) + TieredSmtProof::new(path, entries).expect("Bug detected, TSMT produced invalid proof") } // STATE MUTATORS diff --git a/src/merkle/tiered_smt/proof.rs b/src/merkle/tiered_smt/proof.rs index 0e7fcbf..eae8e38 100644 --- a/src/merkle/tiered_smt/proof.rs +++ b/src/merkle/tiered_smt/proof.rs @@ -1,6 +1,6 @@ use super::{ get_common_prefix_tier_depth, get_key_prefix, hash_bottom_leaf, hash_upper_leaf, - EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, Vec, Word, + EmptySubtreeRoots, LeafNodeIndex, MerklePath, RpoDigest, TieredSmtProofError, Vec, Word, }; // CONSTANTS @@ -12,6 +12,9 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// Value of an empty leaf. pub const EMPTY_VALUE: Word = super::TieredSmt::EMPTY_VALUE; +/// Depths at which leaves can exist in a tiered SMT. 
+pub const TIER_DEPTHS: [u8; 4] = super::TieredSmt::TIER_DEPTHS; + // TIERED SPARSE MERKLE TREE PROOF // ================================================================================================ @@ -39,23 +42,38 @@ impl TieredSmtProof { /// - Entries contains more than 1 item, but the length of the path is not 64. /// - Entries contains more than 1 item, and one of the items has value set to [ZERO; 4]. /// - Entries contains multiple items with keys which don't share the same 64-bit prefix. - pub fn new(path: MerklePath, entries: I) -> Self + pub fn new(path: MerklePath, entries: I) -> Result where I: IntoIterator, { let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect(); - assert!(path.depth() <= MAX_DEPTH); - assert!(!entries.is_empty()); + + if !TIER_DEPTHS.into_iter().any(|e| e == path.depth()) { + return Err(TieredSmtProofError::NotATierPath(path.depth())); + } + + if entries.is_empty() { + return Err(TieredSmtProofError::EntriesEmpty); + } + if entries.len() > 1 { - assert!(path.depth() == MAX_DEPTH); + if path.depth() != MAX_DEPTH { + return Err(TieredSmtProofError::MultipleEntriesOutsideLastTier); + } + let prefix = get_key_prefix(&entries[0].0); for entry in entries.iter().skip(1) { - assert_ne!(entry.1, EMPTY_VALUE); - assert_eq!(prefix, get_key_prefix(&entry.0)); + if entry.1 == EMPTY_VALUE { + return Err(TieredSmtProofError::EmptyValueNotAllowed); + } + let current = get_key_prefix(&entry.0); + if prefix != current { + return Err(TieredSmtProofError::UnmatchingPrefixes(prefix, current)); + } } } - Self { path, entries } + Ok(Self { path, entries }) } // PROOF VERIFIER From 5c6a20cb6094fba6b2e915c87cf44409b2659b3f Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Fri, 4 Aug 2023 22:36:45 -0700 Subject: [PATCH 14/32] fix: bug in TSMT for depth 64 removal --- src/merkle/tiered_smt/mod.rs | 21 ++++-- src/merkle/tiered_smt/nodes.rs | 57 ++++++++++++--- src/merkle/tiered_smt/tests.rs | 125 +++++++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 15 deletions(-) diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index a0bd723..fbdf2d3 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -41,7 +41,7 @@ mod tests; /// To differentiate between internal and leaf nodes, node values are computed as follows: /// - Internal nodes: hash(left_child, right_child). /// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth). -/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n, domain=64]). +/// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64). #[derive(Debug, Clone, PartialEq, Eq)] pub struct TieredSmt { root: RpoDigest, @@ -306,10 +306,23 @@ impl TieredSmt { debug_assert!(leaf_exists); // if the leaf is at the bottom tier and after removing the key-value pair from it, the - // leaf is still not empty, just recompute its hash and update the leaf node. 
+ // leaf is still not empty, we either just update it, or move it up to a higher tier (if + // the leaf doesn't have siblings at lower tiers) if index.depth() == Self::MAX_DEPTH { - if let Some(values) = self.values.get_all(index.value()) { - let node = hash_bottom_leaf(&values); + if let Some(entries) = self.values.get_all(index.value()) { + // if there is only one key-value pair left at the bottom leaf, and it can be + // moved up to a higher tier, truncate the branch and return + if entries.len() == 1 { + let new_depth = self.nodes.get_last_single_child_parent_depth(index.value()); + if new_depth != Self::MAX_DEPTH { + let node = hash_upper_leaf(entries[0].0, entries[0].1, new_depth); + self.root = self.nodes.truncate_branch(index.value(), new_depth, node); + return old_value; + } + } + + // otherwise just recompute the leaf hash and update the leaf node + let node = hash_bottom_leaf(&entries); self.root = self.nodes.update_leaf_node(index, node); return old_value; }; diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 9db4ba3..0d94091 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -118,6 +118,24 @@ impl NodeStore { (index, self.bottom_leaves.contains(&index.value())) } + /// Traverses the tree up from the bottom tier starting at the specified leaf index and + /// returns the depth of the first node which has more than one child. The returned depth + /// is rounded up to the next tier. + pub fn get_last_single_child_parent_depth(&self, leaf_index: u64) -> u8 { + let mut index = NodeIndex::new_unchecked(MAX_DEPTH, leaf_index); + + for _ in (TIER_DEPTHS[0]..MAX_DEPTH).rev() { + let sibling_index = index.sibling(); + if self.nodes.contains_key(&sibling_index) { + break; + } + index.move_up(); + } + + let tier = (index.depth() - 1) / TIER_SIZE; + TIER_DEPTHS[tier as usize] + } + // ITERATORS // -------------------------------------------------------------------------------------------- @@ -201,15 +219,6 @@ impl NodeStore { debug_assert_eq!(removed_leaf.depth(), retained_leaf.depth()); debug_assert!(removed_leaf.depth() > new_depth); - // clear leaf flags - if removed_leaf.depth() == MAX_DEPTH { - self.bottom_leaves.remove(&removed_leaf.value()); - self.bottom_leaves.remove(&retained_leaf.value()); - } else { - self.upper_leaves.remove(&removed_leaf); - self.upper_leaves.remove(&retained_leaf); - } - // remove the branches leading up to the tier to which the retained leaf is to be moved self.remove_branch(removed_leaf, new_depth); self.remove_branch(retained_leaf, new_depth); @@ -298,6 +307,25 @@ impl NodeStore { self.update_leaf_node(index, node) } + /// Truncates a branch starting with the specified leaf at the bottom tier to the new depth. + /// + /// This involves removing the part of the branch below the new depth, and then inserting a new + /// node at the new depth.
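To make the tier rounding in `get_last_single_child_parent_depth` above concrete, a small worked example (assuming `TIER_SIZE` is 16 and `TIER_DEPTHS` is `[16, 32, 48, 64]`, as used elsewhere in this module): if the upward walk stops at depth 35 because a sibling exists there, the deepest node with two children sits at depth 34, so the lone leaf can be re-homed at the shallowest tier at depth 35 or greater, which is 48:

assert_eq!((35u8 - 1) / 16, 2);             // tier index
assert_eq!([16u8, 32, 48, 64][2usize], 48); // the leaf gets promoted to depth 48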
+ pub fn truncate_branch( + &mut self, + leaf_index: u64, + new_depth: u8, + node: RpoDigest, + ) -> RpoDigest { + debug_assert!(self.bottom_leaves.contains(&leaf_index)); + + let mut leaf_index = LeafNodeIndex::new(NodeIndex::new_unchecked(MAX_DEPTH, leaf_index)); + self.remove_branch(leaf_index, new_depth); + + leaf_index.move_up_to(new_depth); + self.insert_leaf_node(leaf_index, node) + } + // HELPER METHODS // -------------------------------------------------------------------------------------------- @@ -360,11 +388,18 @@ impl NodeStore { } } - /// Removes a sequence of nodes starting at the specified index and traversing the - /// tree up to the specified depth. The node at the `end_depth` is also removed. + /// Removes a sequence of nodes starting at the specified index and traversing the tree up to + /// the specified depth. The node at the `end_depth` is also removed, and the appropriate leaf + /// flag is cleared. /// /// This method does not update any other nodes and does not recompute the tree root. fn remove_branch(&mut self, index: LeafNodeIndex, end_depth: u8) { + if index.depth() == MAX_DEPTH { + self.bottom_leaves.remove(&index.value()); + } else { + self.upper_leaves.remove(&index); + } + let mut index: NodeIndex = index.into(); assert!(index.depth() > end_depth); for _ in 0..(index.depth() - end_depth + 1) { diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index c1b7649..61e6081 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -460,6 +460,131 @@ fn tsmt_delete_64() { assert_eq!(smt, smt0); } +#[test] +fn tsmt_delete_64_leaf_promotion() { + let mut smt = TieredSmt::default(); + + // --- delete from bottom tier (no promotion to upper tiers) -------------- + + // insert a value into the tree + let raw_a = 0b_01010101_01010101_11111111_11111111_10101010_10101010_11111111_00000000_u64; + let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_a)]); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // insert another value with a key having the same 64-bit prefix + let key_b = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw_a)]); + let value_b = [ONE, ONE, ONE, ZERO]; + smt.insert(key_b, value_b); + + // insert a value with a key which shared the same 48-bit prefix + let raw_c = 0b_01010101_01010101_11111111_11111111_10101010_10101010_00111111_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + let value_c = [ONE, ONE, ZERO, ZERO]; + smt.insert(key_c, value_c); + + // delete entry A and compare to the tree which was built from B and C + smt.insert(key_a, EMPTY_WORD); + + let mut expected_smt = TieredSmt::default(); + expected_smt.insert(key_b, value_b); + expected_smt.insert(key_c, value_c); + assert_eq!(smt, expected_smt); + + // entries B and C should stay at depth 64 + assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 64); + assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 64); + + // --- delete from bottom tier (promotion to depth 48) -------------------- + + let mut smt = TieredSmt::default(); + smt.insert(key_a, value_a); + smt.insert(key_b, value_b); + + // insert a value with a key which shared the same 32-bit prefix + let raw_c = 0b_01010101_01010101_11111111_11111111_11101010_10101010_11111111_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + smt.insert(key_c, value_c); + + // delete entry A and compare to the tree which was built from B and C + smt.insert(key_a, EMPTY_WORD); + + let mut expected_smt = 
TieredSmt::default(); + expected_smt.insert(key_b, value_b); + expected_smt.insert(key_c, value_c); + assert_eq!(smt, expected_smt); + + // entry B moves to depth 48, entry C stays at depth 48 + assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 48); + assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 48); + + // --- delete from bottom tier (promotion to depth 32) -------------------- + + let mut smt = TieredSmt::default(); + smt.insert(key_a, value_a); + smt.insert(key_b, value_b); + + // insert a value with a key which shared the same 16-bit prefix + let raw_c = 0b_01010101_01010101_01111111_11111111_10101010_10101010_11111111_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + smt.insert(key_c, value_c); + + // delete entry A and compare to the tree which was built from B and C + smt.insert(key_a, EMPTY_WORD); + + let mut expected_smt = TieredSmt::default(); + expected_smt.insert(key_b, value_b); + expected_smt.insert(key_c, value_c); + assert_eq!(smt, expected_smt); + + // entry B moves to depth 32, entry C stays at depth 32 + assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 32); + assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 32); + + // --- delete from bottom tier (promotion to depth 16) -------------------- + + let mut smt = TieredSmt::default(); + smt.insert(key_a, value_a); + smt.insert(key_b, value_b); + + // insert a value with a key which shared prefix < 16 bits + let raw_c = 0b_01010101_01010100_11111111_11111111_10101010_10101010_11111111_00000000_u64; + let key_c = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw_c)]); + smt.insert(key_c, value_c); + + // delete entry A and compare to the tree which was built from B and C + smt.insert(key_a, EMPTY_WORD); + + let mut expected_smt = TieredSmt::default(); + expected_smt.insert(key_b, value_b); + expected_smt.insert(key_c, value_c); + assert_eq!(smt, expected_smt); + + // entry B moves to depth 16, entry C stays at depth 16 + assert_eq!(smt.nodes.get_leaf_index(&key_b).0.depth(), 16); + assert_eq!(smt.nodes.get_leaf_index(&key_c).0.depth(), 16); +} + +#[test] +fn test_order_sensitivity() { + let raw = 0b_10101010_10101010_00011111_11111111_10010110_10010011_11100000_00000001_u64; + let value = [ONE; WORD_SIZE]; + + let key_1 = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2 = RpoDigest::from([ONE, ONE, ZERO, Felt::new(raw)]); + + let mut smt_1 = TieredSmt::default(); + + smt_1.insert(key_1, value); + smt_1.insert(key_2, value); + smt_1.insert(key_2, [ZERO; WORD_SIZE]); + + let mut smt_2 = TieredSmt::default(); + smt_2.insert(key_1, value); + + assert_eq!(smt_1.root(), smt_2.root()); +} + // BOTTOM TIER TESTS // ================================================================================================ From b3e7578ab228c82936a57ae98e7d62987d4357aa Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Fri, 4 Aug 2023 22:46:23 -0700 Subject: [PATCH 15/32] fix: misspelled variant name in TieredSmtProofError --- src/merkle/tiered_smt/error.rs | 45 +++++++++++++++++----------------- src/merkle/tiered_smt/proof.rs | 2 +- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/merkle/tiered_smt/error.rs b/src/merkle/tiered_smt/error.rs index fdc6123..92f7dea 100644 --- a/src/merkle/tiered_smt/error.rs +++ b/src/merkle/tiered_smt/error.rs @@ -3,11 +3,11 @@ use core::fmt::Display; #[derive(Debug, PartialEq, Eq)] pub enum TieredSmtProofError { EntriesEmpty, - PathTooLong, - NotATierPath(u8), - MultipleEntriesOutsideLastTier, EmptyValueNotAllowed, - 
UnmatchingPrefixes(u64, u64), + MismatchedPrefixes(u64, u64), + MultipleEntriesOutsideLastTier, + NotATierPath(u8), + PathTooLong, } impl Display for TieredSmtProofError { @@ -16,31 +16,30 @@ impl Display for TieredSmtProofError { TieredSmtProofError::EntriesEmpty => { write!(f, "Missing entries for tiered sparse merkle tree proof") } + TieredSmtProofError::EmptyValueNotAllowed => { + write!( + f, + "The empty value [0, 0, 0, 0] is not allowed inside a tiered sparse merkle tree" + ) + } + TieredSmtProofError::MismatchedPrefixes(first, second) => { + write!(f, "Not all leaves have the same prefix. First {first} second {second}") + } + TieredSmtProofError::MultipleEntriesOutsideLastTier => { + write!(f, "Multiple entries are only allowed for the last tier (depth 64)") + } + TieredSmtProofError::NotATierPath(got) => { + write!( + f, + "Path length does not correspond to a tier. Got {got} Expected one of 16, 32, 48, 64" + ) + } TieredSmtProofError::PathTooLong => { write!( f, "Path longer than maximum depth of 64 for tiered sparse merkle tree proof" ) } - TieredSmtProofError::NotATierPath(got) => { - write!( - f, - "Path length does not correspond to a tier. Got {} Expected one of 16,32,48,64", - got - ) - } - TieredSmtProofError::MultipleEntriesOutsideLastTier => { - write!(f, "Multiple entries are only allowed for the last tier (depth 64)") - } - TieredSmtProofError::EmptyValueNotAllowed => { - write!( - f, - "The empty value [0,0,0,0] is not allowed inside a tiered sparse merkle tree" - ) - } - TieredSmtProofError::UnmatchingPrefixes(first, second) => { - write!(f, "Not all leaves have the same prefix. First {} second {}", first, second) - } } } } diff --git a/src/merkle/tiered_smt/proof.rs b/src/merkle/tiered_smt/proof.rs index eae8e38..28ac288 100644 --- a/src/merkle/tiered_smt/proof.rs +++ b/src/merkle/tiered_smt/proof.rs @@ -68,7 +68,7 @@ impl TieredSmtProof { } let current = get_key_prefix(&entry.0); if prefix != current { - return Err(TieredSmtProofError::UnmatchingPrefixes(prefix, current)); + return Err(TieredSmtProofError::MismatchedPrefixes(prefix, current)); } } } From f71d98970baede92671976031227d41082a8d515 Mon Sep 17 00:00:00 2001 From: "Augusto F. Hack" Date: Mon, 7 Aug 2023 11:13:24 +0200 Subject: [PATCH 16/32] tsmt: export smt error --- src/merkle/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 9580973..1e687e7 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -27,7 +27,7 @@ mod simple_smt; pub use simple_smt::SimpleSmt; mod tiered_smt; -pub use tiered_smt::{TieredSmt, TieredSmtProof}; +pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError}; mod mmr; pub use mmr::{Mmr, MmrPeaks, MmrProof}; From 8cf5e9fd2c325eb67f8d05150fea3c8d59351056 Mon Sep 17 00:00:00 2001 From: "Augusto F. 
Hack" Date: Tue, 8 Aug 2023 17:04:25 +0200 Subject: [PATCH 17/32] feature: add conditional support for serde --- Cargo.toml | 2 + src/hash/blake/mod.rs | 21 ++++++- src/hash/rpo/digest.rs | 99 +++++++++++++++++++++++++++++---- src/merkle/delta.rs | 3 + src/merkle/index.rs | 1 + src/merkle/merkle_tree.rs | 1 + src/merkle/mmr/accumulator.rs | 1 + src/merkle/mmr/full.rs | 1 + src/merkle/mmr/proof.rs | 1 + src/merkle/node.rs | 1 + src/merkle/partial_mt/mod.rs | 1 + src/merkle/path.rs | 1 + src/merkle/simple_smt/mod.rs | 2 + src/merkle/store/mod.rs | 2 + src/merkle/tiered_smt/mod.rs | 1 + src/merkle/tiered_smt/nodes.rs | 1 + src/merkle/tiered_smt/values.rs | 2 + src/utils/mod.rs | 77 ++++++++++++++++++++++++- 18 files changed, 204 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5caa2be..3a863d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,12 +27,14 @@ harness = false [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +serde = ["winter_math/serde", "dep:serde", "serde/alloc"] [dependencies] blake3 = { version = "1.4", default-features = false } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } +serde = { version = "1.0", features = [ "derive" ], optional = true, default-features = false } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/hash/blake/mod.rs b/src/hash/blake/mod.rs index 91c9bca..9f02eec 100644 --- a/src/hash/blake/mod.rs +++ b/src/hash/blake/mod.rs @@ -1,5 +1,8 @@ use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField}; -use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use crate::utils::{ + bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, + DeserializationError, HexParseError, Serializable, +}; use core::{ mem::{size_of, transmute, transmute_copy}, ops::Deref, @@ -24,6 +27,8 @@ const DIGEST20_BYTES: usize = 20; /// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32 /// bytes. 
#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] pub struct Blake3Digest([u8; N]); impl Default for Blake3Digest { @@ -52,6 +57,20 @@ impl From<[u8; N]> for Blake3Digest { } } +impl From> for String { + fn from(value: Blake3Digest) -> Self { + bytes_to_hex_string(value.as_bytes()) + } +} + +impl TryFrom<&str> for Blake3Digest { + type Error = HexParseError; + + fn try_from(value: &str) -> Result { + hex_to_bytes(value).map(|v| v.into()) + } +} + impl Serializable for Blake3Digest { fn write_into(&self, target: &mut W) { target.write_bytes(&self.0); diff --git a/src/hash/rpo/digest.rs b/src/hash/rpo/digest.rs index 0e6c310..efeda3f 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rpo/digest.rs @@ -1,13 +1,19 @@ use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO}; use crate::utils::{ - string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, + bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, + DeserializationError, HexParseError, Serializable, }; use core::{cmp::Ordering, fmt::Display, ops::Deref}; +/// The number of bytes needed to encoded a digest +pub const DIGEST_BYTES: usize = 32; + // DIGEST TRAIT IMPLEMENTATIONS // ================================================================================================ #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] pub struct RpoDigest([Felt; DIGEST_SIZE]); impl RpoDigest { @@ -19,7 +25,7 @@ impl RpoDigest { self.as_ref() } - pub fn as_bytes(&self) -> [u8; 32] { + pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] { ::as_bytes(self) } @@ -32,8 +38,8 @@ impl RpoDigest { } impl Digest for RpoDigest { - fn as_bytes(&self) -> [u8; 32] { - let mut result = [0; 32]; + fn as_bytes(&self) -> [u8; DIGEST_BYTES] { + let mut result = [0; DIGEST_BYTES]; result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes()); result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes()); @@ -107,18 +113,73 @@ impl From for [u64; DIGEST_SIZE] { } } -impl From<&RpoDigest> for [u8; 32] { +impl From<&RpoDigest> for [u8; DIGEST_BYTES] { fn from(value: &RpoDigest) -> Self { value.as_bytes() } } -impl From for [u8; 32] { +impl From for [u8; DIGEST_BYTES] { fn from(value: RpoDigest) -> Self { value.as_bytes() } } +impl From for String { + fn from(value: RpoDigest) -> Self { + bytes_to_hex_string(value.as_bytes()) + } +} + +impl From<&RpoDigest> for String { + fn from(value: &RpoDigest) -> Self { + (*value).into() + } +} + +impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: [u8; DIGEST_BYTES]) -> Result { + // Note: the input length is known, the conversion from slice to array must succeed so the + // `unwrap`s below are safe + let a = u64::from_le_bytes(value[0..8].try_into().unwrap()); + let b = u64::from_le_bytes(value[8..16].try_into().unwrap()); + let c = u64::from_le_bytes(value[16..24].try_into().unwrap()); + let d = u64::from_le_bytes(value[24..32].try_into().unwrap()); + + if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) { + return Err(HexParseError::OutOfRange); + } + + Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)])) + } +} + +impl TryFrom<&str> for RpoDigest { + type Error = HexParseError; + + fn 
try_from(value: &str) -> Result { + hex_to_bytes(value).and_then(|v| v.try_into()) + } +} + +impl TryFrom for RpoDigest { + type Error = HexParseError; + + fn try_from(value: String) -> Result { + value.as_str().try_into() + } +} + +impl TryFrom<&String> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: &String) -> Result { + value.as_str().try_into() + } +} + impl Deref for RpoDigest { type Target = [Felt; DIGEST_SIZE]; @@ -158,9 +219,8 @@ impl PartialOrd for RpoDigest { impl Display for RpoDigest { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - for byte in self.as_bytes() { - write!(f, "{byte:02x}")?; - } + let encoded: String = self.into(); + write!(f, "{}", encoded)?; Ok(()) } } @@ -170,8 +230,7 @@ impl Display for RpoDigest { #[cfg(test)] mod tests { - - use super::{Deserializable, Felt, RpoDigest, Serializable}; + use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES}; use crate::utils::SliceReader; use rand_utils::rand_value; @@ -186,11 +245,27 @@ mod tests { let mut bytes = vec![]; d1.write_into(&mut bytes); - assert_eq!(32, bytes.len()); + assert_eq!(DIGEST_BYTES, bytes.len()); let mut reader = SliceReader::new(&bytes); let d2 = RpoDigest::read_from(&mut reader).unwrap(); assert_eq!(d1, d2); } + + #[cfg(feature = "std")] + #[test] + fn digest_encoding() { + let digest = RpoDigest([ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]); + + let string: String = digest.into(); + let round_trip: RpoDigest = string.try_into().expect("decoding failed"); + + assert_eq!(digest, round_trip); + } } diff --git a/src/merkle/delta.rs b/src/merkle/delta.rs index 71b822a..cf6d1b9 100644 --- a/src/merkle/delta.rs +++ b/src/merkle/delta.rs @@ -13,6 +13,7 @@ use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt}; /// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the /// differences between the initial and final Merkle tree states. #[derive(Default, Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); // MERKLE TREE DELTA @@ -26,6 +27,7 @@ pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>); /// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values. #[cfg(not(test))] #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTreeDelta { depth: u8, cleared_slots: Vec, @@ -107,6 +109,7 @@ pub fn merkle_tree_delta>( // -------------------------------------------------------------------------------------------- #[cfg(test)] #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTreeDelta { pub depth: u8, pub cleared_slots: Vec, diff --git a/src/merkle/index.rs b/src/merkle/index.rs index 3a79ac0..25c9282 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -21,6 +21,7 @@ use core::fmt::Display; /// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child /// $(1, 1)$. 
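Stepping back from the individual derives for a moment: once the new optional `serde` feature is enabled (for example by building the crate with `--features serde`), these types can be used directly with serde-based formats. A hedged sketch, assuming `serde_json` is available on the consumer side purely for illustration; `RpoDigest` round-trips through its `0x`-prefixed hex string because of the `into = "String"` / `try_from = "&str"` attributes added above:

use miden_crypto::{hash::rpo::RpoDigest, ONE};

fn digest_json_round_trip() {
    let digest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let json = serde_json::to_string(&digest).unwrap(); // a JSON string such as "0x..."
    let back: RpoDigest = serde_json::from_str(&json).unwrap();
    assert_eq!(digest, back);
}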
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct NodeIndex { depth: u8, value: u64, diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index cfb61bc..206543a 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -8,6 +8,7 @@ use winter_math::log2; /// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two). #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleTree { nodes: Vec, } diff --git a/src/merkle/mmr/accumulator.rs b/src/merkle/mmr/accumulator.rs index 0729c94..a610fe7 100644 --- a/src/merkle/mmr/accumulator.rs +++ b/src/merkle/mmr/accumulator.rs @@ -4,6 +4,7 @@ use super::{ }; #[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MmrPeaks { /// The number of leaves is used to differentiate accumulators that have the same number of /// peaks. This happens because the number of peaks goes up-and-down as the structure is used diff --git a/src/merkle/mmr/full.rs b/src/merkle/mmr/full.rs index d2fbbeb..c3dd3ac 100644 --- a/src/merkle/mmr/full.rs +++ b/src/merkle/mmr/full.rs @@ -29,6 +29,7 @@ use std::error::Error; /// Since this is a full representation of the MMR, elements are never removed and the MMR will /// grow roughly `O(2n)` in number of leaf elements. #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct Mmr { /// Refer to the `forest` method documentation for details of the semantics of this value. pub(super) forest: usize, diff --git a/src/merkle/mmr/proof.rs b/src/merkle/mmr/proof.rs index 0904b83..d9b4bcf 100644 --- a/src/merkle/mmr/proof.rs +++ b/src/merkle/mmr/proof.rs @@ -3,6 +3,7 @@ use super::super::MerklePath; use super::full::{high_bitmask, leaf_to_corresponding_tree}; #[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MmrProof { /// The state of the MMR when the MmrProof was created. pub forest: usize, diff --git a/src/merkle/node.rs b/src/merkle/node.rs index 8440af8..4305e7f 100644 --- a/src/merkle/node.rs +++ b/src/merkle/node.rs @@ -2,6 +2,7 @@ use crate::hash::rpo::RpoDigest; /// Representation of a node with two children used for iterating over containers. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct InnerNodeInfo { pub value: RpoDigest, pub left: RpoDigest, diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index a615e18..1231a04 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -28,6 +28,7 @@ const EMPTY_DIGEST: RpoDigest = RpoDigest::new([ZERO; 4]); /// /// The root of the tree is recomputed on each new leaf update. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct PartialMerkleTree { max_depth: u8, nodes: BTreeMap, diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 975bc68..86f66e4 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -6,6 +6,7 @@ use core::ops::{Deref, DerefMut}; /// A merkle path container, composed of a sequence of nodes of a Merkle tree. 
#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerklePath { nodes: Vec, } diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 542ab51..d9e80c1 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -13,6 +13,7 @@ mod tests; /// /// The root of the tree is recomputed on each new leaf update. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct SimpleSmt { depth: u8, root: RpoDigest, @@ -265,6 +266,7 @@ impl SimpleSmt { // ================================================================================================ #[derive(Debug, Default, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] struct BranchNode { left: RpoDigest, right: RpoDigest, diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 8d8b80a..8c23618 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -19,6 +19,7 @@ pub type DefaultMerkleStore = MerkleStore>; pub type RecordingMerkleStore = MerkleStore>; #[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct StoreNode { left: RpoDigest, right: RpoDigest, @@ -87,6 +88,7 @@ pub struct StoreNode { /// assert_eq!(store.num_internal_nodes() - 255, 10); /// ``` #[derive(Debug, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct MerkleStore = BTreeMap> { nodes: T, } diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index fbdf2d3..2cb8792 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -43,6 +43,7 @@ mod tests; /// - Leaf node at depths 16, 32, or 64: hash(key, value, domain=depth). /// - Leaf node at depth 64: hash([key_0, value_0, ..., key_n, value_n], domain=64). #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct TieredSmt { root: RpoDigest, nodes: NodeStore, diff --git a/src/merkle/tiered_smt/nodes.rs b/src/merkle/tiered_smt/nodes.rs index 0d94091..1bb34df 100644 --- a/src/merkle/tiered_smt/nodes.rs +++ b/src/merkle/tiered_smt/nodes.rs @@ -24,6 +24,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// represent leaf nodes in a Tiered Sparse Merkle tree. In the current implementation, [BTreeSet]s /// are used to determine the position of the leaves in the tree. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct NodeStore { nodes: BTreeMap, upper_leaves: BTreeSet, diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index ec2a465..d41ee6b 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -26,6 +26,7 @@ const MAX_DEPTH: u8 = super::TieredSmt::MAX_DEPTH; /// The store supports lookup by the full key (i.e. [RpoDigest]) as well as by the 64-bit key /// prefix. #[derive(Debug, Default, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct ValueStore { values: BTreeMap, } @@ -173,6 +174,7 @@ impl ValueStore { /// An entry can contain either a single key-value pair or a vector of key-value pairs sorted by /// key. 
#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub enum StoreEntry { Single((RpoDigest, Word)), List(Vec<(RpoDigest, Word)>), diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d71cd33..8aadabe 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,5 +1,5 @@ use super::{utils::string::String, Word}; -use core::fmt::{self, Write}; +use core::fmt::{self, Display, Write}; #[cfg(not(feature = "std"))] pub use alloc::{format, vec}; @@ -36,3 +36,78 @@ pub fn word_to_hex(w: &Word) -> Result { Ok(s) } + +/// Renders an array of bytes as hex into a String. +pub fn bytes_to_hex_string(data: [u8; N]) -> String { + let mut s = String::with_capacity(N + 2); + + s.push_str("0x"); + for byte in data.iter() { + write!(s, "{byte:02x}").expect("formatting hex failed"); + } + + s +} + +#[derive(Debug)] +pub enum HexParseError { + InvalidLength { expected: usize, got: usize }, + MissingPrefix, + InvalidChar, + OutOfRange, +} + +impl Display for HexParseError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + HexParseError::InvalidLength { expected, got } => { + write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {got}") + } + HexParseError::MissingPrefix => { + write!(f, "Hex encoded RpoDigest must start with 0x prefix") + } + HexParseError::InvalidChar => { + write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]") + } + HexParseError::OutOfRange => { + write!(f, "Hex encoded values of an RpoDigest must be inside the field modulus") + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for HexParseError {} + +/// Parses a hex string into an array of bytes of known size. +pub fn hex_to_bytes(value: &str) -> Result<[u8; N], HexParseError> { + let expected: usize = (N * 2) + 2; + if value.len() != expected { + return Err(HexParseError::InvalidLength { + expected, + got: value.len(), + }); + } + + if !value.starts_with("0x") { + return Err(HexParseError::MissingPrefix); + } + + let mut data = value.bytes().skip(2).map(|v| match v { + b'0'..=b'9' => Ok(v - b'0'), + b'a'..=b'f' => Ok(v - b'a' + 10), + b'A'..=b'F' => Ok(v - b'A' + 10), + _ => Err(HexParseError::InvalidChar), + }); + + let mut decoded = [0u8; N]; + #[allow(clippy::needless_range_loop)] + for pos in 0..N { + // These `unwrap` calls are okay because the length was checked above + let high: u8 = data.next().unwrap()?; + let low: u8 = data.next().unwrap()?; + decoded[pos] = (high << 4) + low; + } + + Ok(decoded) +} From fb649df1e769ae1b26fd9ccd76005cecf8ce60fa Mon Sep 17 00:00:00 2001 From: tohrnii <100405913+tohrnii@users.noreply.github.com> Date: Fri, 11 Aug 2023 20:09:34 +0000 Subject: [PATCH 18/32] feat: derive ord and partialord for blake3digest --- src/hash/blake/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hash/blake/mod.rs b/src/hash/blake/mod.rs index 9f02eec..16af67e 100644 --- a/src/hash/blake/mod.rs +++ b/src/hash/blake/mod.rs @@ -26,7 +26,7 @@ const DIGEST20_BYTES: usize = 20; /// /// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32 /// bytes. 
-#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] pub struct Blake3Digest([u8; N]); From 6d0c7567f056dd894b361ec8166f333fdf748d87 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 12 Aug 2023 09:59:02 -0700 Subject: [PATCH 19/32] chore: minor code organization improvement --- PQClean | 1 + src/hash/rpo/digest.rs | 128 +++++++++++++++++++++++------------------ 2 files changed, 72 insertions(+), 57 deletions(-) create mode 160000 PQClean diff --git a/PQClean b/PQClean new file mode 160000 index 0000000..c3abebf --- /dev/null +++ b/PQClean @@ -0,0 +1 @@ +Subproject commit c3abebf4ab1ff516ffa71e6337f06d898952c299 diff --git a/src/hash/rpo/digest.rs b/src/hash/rpo/digest.rs index efeda3f..18071ae 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rpo/digest.rs @@ -50,35 +50,54 @@ impl Digest for RpoDigest { } } -impl Serializable for RpoDigest { - fn write_into(&self, target: &mut W) { - target.write_bytes(&self.as_bytes()); +impl Deref for RpoDigest { + type Target = [Felt; DIGEST_SIZE]; + + fn deref(&self) -> &Self::Target { + &self.0 } } -impl Deserializable for RpoDigest { - fn read_from(source: &mut R) -> Result { - let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE]; - for inner in inner.iter_mut() { - let e = source.read_u64()?; - if e >= Felt::MODULUS { - return Err(DeserializationError::InvalidValue(String::from( - "Value not in the appropriate range", - ))); - } - *inner = Felt::new(e); - } - - Ok(Self(inner)) +impl Ord for RpoDigest { + fn cmp(&self, other: &Self) -> Ordering { + // compare the inner u64 of both elements. + // + // it will iterate the elements and will return the first computation different than + // `Equal`. Otherwise, the ordering is equal. + // + // the endianness is irrelevant here because since, this being a cryptographically secure + // hash computation, the digest shouldn't have any ordered property of its input. + // + // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a + // montgomery reduction for every limb. that is safe because every inner element of the + // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`). + self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold( + Ordering::Equal, + |ord, (a, b)| match ord { + Ordering::Equal => a.cmp(&b), + _ => ord, + }, + ) } } -impl From<[Felt; DIGEST_SIZE]> for RpoDigest { - fn from(value: [Felt; DIGEST_SIZE]) -> Self { - Self(value) +impl PartialOrd for RpoDigest { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } +impl Display for RpoDigest { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let encoded: String = self.into(); + write!(f, "{}", encoded)?; + Ok(()) + } +} + +// CONVERSIONS: FROM RPO DIGEST +// ================================================================================================ + impl From<&RpoDigest> for [Felt; DIGEST_SIZE] { fn from(value: &RpoDigest) -> Self { value.0 @@ -126,17 +145,28 @@ impl From for [u8; DIGEST_BYTES] { } impl From for String { + /// The returned string starts with `0x`. fn from(value: RpoDigest) -> Self { bytes_to_hex_string(value.as_bytes()) } } impl From<&RpoDigest> for String { + /// The returned string starts with `0x`. 
fn from(value: &RpoDigest) -> Self { (*value).into() } } +// CONVERSIONS: TO DIGEST +// ================================================================================================ + +impl From<[Felt; DIGEST_SIZE]> for RpoDigest { + fn from(value: [Felt; DIGEST_SIZE]) -> Self { + Self(value) + } +} + impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest { type Error = HexParseError; @@ -159,6 +189,7 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest { impl TryFrom<&str> for RpoDigest { type Error = HexParseError; + /// Expects the string to start with `0x`. fn try_from(value: &str) -> Result { hex_to_bytes(value).and_then(|v| v.try_into()) } @@ -167,6 +198,7 @@ impl TryFrom<&str> for RpoDigest { impl TryFrom for RpoDigest { type Error = HexParseError; + /// Expects the string to start with `0x`. fn try_from(value: String) -> Result { value.as_str().try_into() } @@ -175,53 +207,35 @@ impl TryFrom for RpoDigest { impl TryFrom<&String> for RpoDigest { type Error = HexParseError; + /// Expects the string to start with `0x`. fn try_from(value: &String) -> Result { value.as_str().try_into() } } -impl Deref for RpoDigest { - type Target = [Felt; DIGEST_SIZE]; +// SERIALIZATION / DESERIALIZATION +// ================================================================================================ - fn deref(&self) -> &Self::Target { - &self.0 +impl Serializable for RpoDigest { + fn write_into(&self, target: &mut W) { + target.write_bytes(&self.as_bytes()); } } -impl Ord for RpoDigest { - fn cmp(&self, other: &Self) -> Ordering { - // compare the inner u64 of both elements. - // - // it will iterate the elements and will return the first computation different than - // `Equal`. Otherwise, the ordering is equal. - // - // the endianness is irrelevant here because since, this being a cryptographically secure - // hash computation, the digest shouldn't have any ordered property of its input. - // - // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a - // montgomery reduction for every limb. that is safe because every inner element of the - // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`). 
- self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold( - Ordering::Equal, - |ord, (a, b)| match ord { - Ordering::Equal => a.cmp(&b), - _ => ord, - }, - ) - } -} +impl Deserializable for RpoDigest { + fn read_from(source: &mut R) -> Result { + let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE]; + for inner in inner.iter_mut() { + let e = source.read_u64()?; + if e >= Felt::MODULUS { + return Err(DeserializationError::InvalidValue(String::from( + "Value not in the appropriate range", + ))); + } + *inner = Felt::new(e); + } -impl PartialOrd for RpoDigest { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Display for RpoDigest { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let encoded: String = self.into(); - write!(f, "{}", encoded)?; - Ok(()) + Ok(Self(inner)) } } From 7780a50dadf2f4becdc121470b63d581d00b6524 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Sat, 12 Aug 2023 21:31:31 -0700 Subject: [PATCH 20/32] fix: remove PQClean submodule --- PQClean | 1 - 1 file changed, 1 deletion(-) delete mode 160000 PQClean diff --git a/PQClean b/PQClean deleted file mode 160000 index c3abebf..0000000 --- a/PQClean +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c3abebf4ab1ff516ffa71e6337f06d898952c299 From f7e6922bffd9c768979aff3bea30d76ce3a22c33 Mon Sep 17 00:00:00 2001 From: "Augusto F. Hack" Date: Tue, 15 Aug 2023 13:53:43 +0200 Subject: [PATCH 21/32] error: moved to its own module --- src/merkle/error.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++ src/merkle/mod.rs | 54 ++------------------------------------------- 2 files changed, 56 insertions(+), 52 deletions(-) create mode 100644 src/merkle/error.rs diff --git a/src/merkle/error.rs b/src/merkle/error.rs new file mode 100644 index 0000000..5012b75 --- /dev/null +++ b/src/merkle/error.rs @@ -0,0 +1,54 @@ +use crate::{ + merkle::{MerklePath, NodeIndex, RpoDigest}, + utils::collections::Vec, +}; +use core::fmt; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MerkleError { + ConflictingRoots(Vec), + DepthTooSmall(u8), + DepthTooBig(u64), + DuplicateValuesForIndex(u64), + DuplicateValuesForKey(RpoDigest), + InvalidIndex { depth: u8, value: u64 }, + InvalidDepth { expected: u8, provided: u8 }, + InvalidPath(MerklePath), + InvalidNumEntries(usize, usize), + NodeNotInSet(NodeIndex), + NodeNotInStore(RpoDigest, NodeIndex), + NumLeavesNotPowerOfTwo(usize), + RootNotInStore(RpoDigest), +} + +impl fmt::Display for MerkleError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use MerkleError::*; + match self { + ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"), + DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"), + DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"), + DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"), + DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"), + InvalidIndex{ depth, value} => write!( + f, + "the index value {value} is not valid for the depth {depth}" + ), + InvalidDepth { expected, provided } => write!( + f, + "the provided depth {provided} is not valid for {expected}" + ), + InvalidPath(_path) => write!(f, "the provided path is not valid"), + InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"), + NodeNotInSet(index) => write!(f, "the node with index ({index}) is not 
in the set"), + NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"), + NumLeavesNotPowerOfTwo(leaves) => { + write!(f, "the leaves count {leaves} is not a power of 2") + } + RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for MerkleError {} diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 1e687e7..4d9bbeb 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -3,7 +3,6 @@ use super::{ utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec}, Felt, StarkField, Word, WORD_SIZE, ZERO, }; -use core::fmt; // REEXPORTS // ================================================================================================ @@ -41,57 +40,8 @@ pub use node::InnerNodeInfo; mod partial_mt; pub use partial_mt::PartialMerkleTree; -// ERRORS -// ================================================================================================ - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum MerkleError { - ConflictingRoots(Vec), - DepthTooSmall(u8), - DepthTooBig(u64), - DuplicateValuesForIndex(u64), - DuplicateValuesForKey(RpoDigest), - InvalidIndex { depth: u8, value: u64 }, - InvalidDepth { expected: u8, provided: u8 }, - InvalidPath(MerklePath), - InvalidNumEntries(usize, usize), - NodeNotInSet(NodeIndex), - NodeNotInStore(RpoDigest, NodeIndex), - NumLeavesNotPowerOfTwo(usize), - RootNotInStore(RpoDigest), -} - -impl fmt::Display for MerkleError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use MerkleError::*; - match self { - ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"), - DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"), - DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"), - DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"), - DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"), - InvalidIndex{ depth, value} => write!( - f, - "the index value {value} is not valid for the depth {depth}" - ), - InvalidDepth { expected, provided } => write!( - f, - "the provided depth {provided} is not valid for {expected}" - ), - InvalidPath(_path) => write!(f, "the provided path is not valid"), - InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"), - NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"), - NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"), - NumLeavesNotPowerOfTwo(leaves) => { - write!(f, "the leaves count {leaves} is not a power of 2") - } - RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root), - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for MerkleError {} +mod error; +pub use error::MerkleError; // HELPER FUNCTIONS // ================================================================================================ From 2214ff2425b05e95f6ba41f1ec164afbd1eb46c1 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Mon, 14 Aug 2023 12:34:14 +0200 Subject: [PATCH 22/32] chore: TSMT benchmark --- Cargo.toml | 11 ++- src/hash/rpo/digest.rs | 14 +++ src/main.rs | 164 +++++++++++++++++++++++++++++++++++ src/merkle/tiered_smt/mod.rs | 6 ++ 4 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 src/main.rs diff --git a/Cargo.toml 
b/Cargo.toml index 3a863d8..04dc990 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,13 @@ keywords = ["miden", "crypto", "hash", "merkle"] edition = "2021" rust-version = "1.67" +[[bin]] +name = "miden-crypto" +path = "src/main.rs" +bench = false +doctest = false +required-features = ["std"] + [[bench]] name = "hash" harness = false @@ -26,15 +33,17 @@ harness = false [features] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] -std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std", "rand_utils"] serde = ["winter_math/serde", "dep:serde", "serde/alloc"] [dependencies] blake3 = { version = "1.4", default-features = false } +clap = { version = "4.3.21", features = ["derive"] } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } serde = { version = "1.0", features = [ "derive" ], optional = true, default-features = false } +rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/hash/rpo/digest.rs b/src/hash/rpo/digest.rs index 18071ae..2a269d6 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rpo/digest.rs @@ -4,6 +4,7 @@ use crate::utils::{ DeserializationError, HexParseError, Serializable, }; use core::{cmp::Ordering, fmt::Display, ops::Deref}; +use winter_utils::Randomizable; /// The number of bytes needed to encoded a digest pub const DIGEST_BYTES: usize = 32; @@ -95,6 +96,19 @@ impl Display for RpoDigest { } } +impl Randomizable for RpoDigest { + const VALUE_SIZE: usize = DIGEST_BYTES; + + fn from_random_bytes(bytes: &[u8]) -> Option { + let bytes_array: Option<[u8; 32]> = bytes.try_into().ok(); + if let Some(bytes_array) = bytes_array { + Self::try_from(bytes_array).ok() + } else { + None + } + } +} + // CONVERSIONS: FROM RPO DIGEST // ================================================================================================ diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..8700dec --- /dev/null +++ b/src/main.rs @@ -0,0 +1,164 @@ +use clap::Parser; +use miden_crypto::{ + hash::rpo::RpoDigest, + merkle::MerkleError, + Felt, Word, ONE, + {hash::rpo::Rpo256, merkle::TieredSmt}, +}; +use rand_utils::rand_value; +use std::time::Instant; + +#[derive(Parser, Debug)] +#[clap( + name = "Benchmark", + about = "Tiered SMT benchmark", + version, + rename_all = "kebab-case" +)] +pub struct BenchmarkCmd { + /// Size of the tree + #[clap(short = 's', long = "size")] + size: u64, + + /// Run the construction benchmark + #[clap(short = 'c', long = "construction")] + construction: bool, + + /// Run the insertion benchmark + #[clap(short = 'i', long = "insertion")] + insertion: bool, + + /// Run the proof generation benchmark + #[clap(short = 'p', long = "proof-generation")] + proof_generation: bool, +} + +fn main() { + benchmark_tsmt(); +} + +/// Run a benchmark for the Tiered SMT. 
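+///
+/// A rough usage sketch (assuming the `miden-crypto` binary target declared in `Cargo.toml`
+/// above and the default `std` feature): `cargo run --release -- --size 1000000`.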
+pub fn benchmark_tsmt() { + let args = BenchmarkCmd::parse(); + let tree_size = args.size; + + // prepare the `leaves` vector for tree creation + let mut leaves = Vec::new(); + for i in 0..tree_size { + let key = rand_value::(); + let value = [ONE, ONE, ONE, Felt::new(i)]; + leaves.push((key, value)); + } + + let mut tree: Option = None; + + // if the `-c` argument was specified + if args.construction { + tree = Some(construction(leaves.clone(), tree_size).unwrap()); + } + + // if the `-i` argument was specified + if args.insertion { + if let Some(inner_tree) = tree { + tree = Some(insertion(inner_tree, tree_size).unwrap()); + } else { + let inner_tree = TieredSmt::with_leaves(leaves.clone()).unwrap(); + tree = Some(insertion(inner_tree, tree_size).unwrap()); + } + } + + // if the `-p` argument was specified + if args.proof_generation { + if let Some(inner_tree) = tree { + proof_generation(inner_tree, tree_size).unwrap(); + } else { + let inner_tree = TieredSmt::with_leaves(leaves).unwrap(); + proof_generation(inner_tree, tree_size).unwrap(); + } + } +} + +/// Run the construction benchmark for the Tiered SMT. +pub fn construction(leaves: Vec<(RpoDigest, Word)>, size: u64) -> Result { + println!("Running a construction benchmark:"); + let now = Instant::now(); + let tree = TieredSmt::with_leaves(leaves)?; + let elapsed = now.elapsed(); + println!( + "Constructed a TSMT with {} key-value pairs in {:.3} seconds", + size, + elapsed.as_secs_f32(), + ); + + // Count how many nodes end up at each tier + let mut nodes_num_16_32_48 = (0, 0, 0); + + tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() { + 16 => nodes_num_16_32_48.0 += 1, + 32 => nodes_num_16_32_48.1 += 1, + 48 => nodes_num_16_32_48.2 += 1, + _ => unreachable!(), + }); + + println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0); + println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1); + println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2); + println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count()); + + Ok(tree) +} + +/// Run the insertion benchmark for the Tiered SMT. +pub fn insertion(mut tree: TieredSmt, size: u64) -> Result { + println!("Running an insertion benchmark:"); + + let mut insertion_times = Vec::new(); + + for i in 0..20 { + let test_key = Rpo256::hash(&rand_value::().to_be_bytes()); + let test_value = [ONE, ONE, ONE, Felt::new(size + i)]; + + let now = Instant::now(); + tree.insert(test_key, test_value); + let elapsed = now.elapsed(); + insertion_times.push(elapsed.as_secs_f32()); + } + + println!( + "An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n", + size, + // calculate the average by dividing by 20 and convert to milliseconds by multiplying by + // 1000. As a result, we can only multiply by 50 + insertion_times.iter().sum::() * 50f32, + ); + + Ok(tree) +} + +/// Run the proof generation benchmark for the Tiered SMT. 
+pub fn proof_generation(mut tree: TieredSmt, size: u64) -> Result<(), MerkleError> { + println!("Running a proof generation benchmark:"); + + let mut insertion_times = Vec::new(); + + for i in 0..20 { + let test_key = Rpo256::hash(&rand_value::().to_be_bytes()); + let test_value = [ONE, ONE, ONE, Felt::new(size + i)]; + tree.insert(test_key, test_value); + + let now = Instant::now(); + let _proof = tree.prove(test_key); + let elapsed = now.elapsed(); + insertion_times.push(elapsed.as_secs_f32()); + } + + println!( + "An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds", + size, + // calculate the average by dividing by 20 and convert to microseconds by multiplying by + // 1000000. As a result, we can only multiply by 50000 + insertion_times.iter().sum::() * 50000f32, + ); + + Ok(()) +} diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 2cb8792..d2cc529 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -274,6 +274,12 @@ impl TieredSmt { }) } + /// Returns an iterator over upper leaves (i.e., depth = 16, 32, or 48) for this [TieredSmt] + /// where each yielded item is a (node_index, value) tuple. + pub fn upper_leaf_nodes(&self) -> impl Iterator { + self.nodes.upper_leaves() + } + /// Returns an iterator over bottom leaves (i.e., depth = 64) of this [TieredSmt]. /// /// Each yielded item consists of the hash of the leaf and its contents, where contents is From c1d061211501dffe3267e15ab23885bee157670d Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Thu, 17 Aug 2023 21:50:01 +0200 Subject: [PATCH 23/32] refactor: run all benchmarks at once, leave only size run option --- src/main.rs | 53 +++++++++-------------------------------------------- 1 file changed, 9 insertions(+), 44 deletions(-) diff --git a/src/main.rs b/src/main.rs index 8700dec..800306c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,18 +19,6 @@ pub struct BenchmarkCmd { /// Size of the tree #[clap(short = 's', long = "size")] size: u64, - - /// Run the construction benchmark - #[clap(short = 'c', long = "construction")] - construction: bool, - - /// Run the insertion benchmark - #[clap(short = 'i', long = "insertion")] - insertion: bool, - - /// Run the proof generation benchmark - #[clap(short = 'p', long = "proof-generation")] - proof_generation: bool, } fn main() { @@ -50,35 +38,12 @@ pub fn benchmark_tsmt() { leaves.push((key, value)); } - let mut tree: Option = None; - - // if the `-c` argument was specified - if args.construction { - tree = Some(construction(leaves.clone(), tree_size).unwrap()); - } - - // if the `-i` argument was specified - if args.insertion { - if let Some(inner_tree) = tree { - tree = Some(insertion(inner_tree, tree_size).unwrap()); - } else { - let inner_tree = TieredSmt::with_leaves(leaves.clone()).unwrap(); - tree = Some(insertion(inner_tree, tree_size).unwrap()); - } - } - - // if the `-p` argument was specified - if args.proof_generation { - if let Some(inner_tree) = tree { - proof_generation(inner_tree, tree_size).unwrap(); - } else { - let inner_tree = TieredSmt::with_leaves(leaves).unwrap(); - proof_generation(inner_tree, tree_size).unwrap(); - } - } + let mut tree = construction(leaves, tree_size).unwrap(); + insertion(&mut tree, tree_size).unwrap(); + proof_generation(&mut tree, tree_size).unwrap(); } -/// Run the construction benchmark for the Tiered SMT. +/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree. 
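+/// It also reports the number of leaves that ended up at each tier (depths 16, 32, 48, and 64).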
pub fn construction(leaves: Vec<(RpoDigest, Word)>, size: u64) -> Result { println!("Running a construction benchmark:"); let now = Instant::now(); @@ -108,8 +73,8 @@ pub fn construction(leaves: Vec<(RpoDigest, Word)>, size: u64) -> Result Result { +/// Runs the insertion benchmark for the Tiered SMT. +pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); let mut insertion_times = Vec::new(); @@ -132,11 +97,11 @@ pub fn insertion(mut tree: TieredSmt, size: u64) -> Result() * 50f32, ); - Ok(tree) + Ok(()) } -/// Run the proof generation benchmark for the Tiered SMT. -pub fn proof_generation(mut tree: TieredSmt, size: u64) -> Result<(), MerkleError> { +/// Runs the proof generation benchmark for the Tiered SMT. +pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> { println!("Running a proof generation benchmark:"); let mut insertion_times = Vec::new(); From 9f54c82d622abfc4c0926238d0a3adc8766366a0 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Wed, 16 Aug 2023 02:54:11 -0700 Subject: [PATCH 24/32] feat: implement additional leaf traversal methods on MerkleStore --- CHANGELOG.md | 5 ++ src/merkle/path.rs | 10 ++++ src/merkle/store/mod.rs | 108 ++++++++++++++++++++++++++--------- src/merkle/store/tests.rs | 64 +++++++++++++++++++++ src/merkle/tiered_smt/mod.rs | 6 +- 5 files changed, 163 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f833a3..9ec4658 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ ## 0.7.0 (TBD) * Replaced `MerklePathSet` with `PartialMerkleTree` (#165). +* Implemented clearing of nodes in `TieredSmt` (#173). +* Added ability to generate inclusion proofs for `TieredSmt` (#174). +* Added conditional `serde`` support for various structs (#180). +* Implemented benchmarking for `TieredSmt` (#182). +* Added more leaf traversal methods for `MerkleStore` (#185). ## 0.6.0 (2023-06-25) diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 86f66e4..2d11bb3 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -160,6 +160,16 @@ pub struct ValuePath { pub path: MerklePath, } +impl ValuePath { + /// Returns a new [ValuePath] instantiated from the specified value and path. + pub fn new(value: RpoDigest, path: Vec) -> Self { + Self { + value, + path: MerklePath::new(path), + } + } +} + /// A container for a [MerklePath] and its [Word] root. /// /// This structure does not provide any guarantees regarding the correctness of the path to the diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 8c23618..5479b0f 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -173,27 +173,24 @@ impl> MerkleStore { // the path is computed from root to leaf, so it must be reversed path.reverse(); - Ok(ValuePath { - value: hash, - path: MerklePath::new(path), - }) + Ok(ValuePath::new(hash, path)) } - /// Reconstructs a path from the root until a leaf or empty node and returns its depth. + // LEAF TRAVERSAL + // -------------------------------------------------------------------------------------------- + + /// Returns the depth of the first leaf or an empty node encountered while traversing the tree + /// from the specified root down according to the provided index. /// - /// The `tree_depth` parameter defines up to which depth the tree will be traversed, starting - /// from `root`. The maximum value the argument accepts is [u64::BITS]. 
- /// - /// The traversed path from leaf to root will start at the least significant bit of `index`, - /// and will be executed for `tree_depth` bits. + /// The `tree_depth` parameter specifies the depth of the tree rooted at `root`. The + /// maximum value the argument accepts is [u64::BITS]. /// /// # Errors /// Will return an error if: /// - The provided root is not found. - /// - The path from the root continues to a depth greater than `tree_depth`. - /// - The provided `tree_depth` is greater than `64. - /// - The provided `index` is not valid for a depth equivalent to `tree_depth`. For more - /// information, check [NodeIndex::new]. + /// - The provided `tree_depth` is greater than 64. + /// - The provided `index` is not valid for a depth equivalent to `tree_depth`. + /// - No leaf or an empty node was found while traversing the tree down to `tree_depth`. pub fn get_leaf_depth( &self, root: RpoDigest, @@ -206,13 +203,6 @@ impl> MerkleStore { } NodeIndex::new(tree_depth, index)?; - // it's not illegal to have a maximum depth of `0`; we should just return the root in that - // case. this check will simplify the implementation as we could overflow bits for depth - // `0`. - if tree_depth == 0 { - return Ok(0); - } - // check if the root exists, providing the proper error report if it doesn't let empty = EmptySubtreeRoots::empty_hashes(tree_depth); let mut hash = root; @@ -224,7 +214,7 @@ impl> MerkleStore { let mut path = (index << (64 - tree_depth)).reverse_bits(); // iterate every depth and reconstruct the path from root to leaf - for depth in 0..tree_depth { + for depth in 0..=tree_depth { // we short-circuit if an empty node has been found if hash == empty[depth as usize] { return Ok(depth); @@ -241,13 +231,77 @@ impl> MerkleStore { path >>= 1; } - // at max depth assert it doesn't have sub-trees - if self.nodes.contains_key(&hash) { - return Err(MerkleError::DepthTooBig(tree_depth as u64 + 1)); + // return an error because we exhausted the index but didn't find either a leaf or an + // empty node + Err(MerkleError::DepthTooBig(tree_depth as u64 + 1)) + } + + /// Returns index and value of a leaf node which is the only leaf node in a subtree defined by + /// the provided root. If the subtree contains zero or more than one leaf nodes None is + /// returned. + /// + /// The `tree_depth` parameter specifies the depth of the parent tree such that `root` is + /// located in this tree at `root_index`. The maximum value the argument accepts is + /// [u64::BITS]. + /// + /// # Errors + /// Will return an error if: + /// - The provided root is not found. + /// - The provided `tree_depth` is greater than 64. + /// - The provided `root_index` has depth greater than `tree_depth`. + /// - A lone node at depth `tree_depth` is not a leaf node. 
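+    ///
+    /// For example (mirroring the tests added below), `store.find_lone_leaf(root, root_index, 64)`
+    /// returns `Ok(Some((leaf_index, leaf_hash)))` when exactly one leaf sits under `root`, and
+    /// `Ok(None)` when the subtree below it contains zero or more than one leaf.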
+ pub fn find_lone_leaf( + &self, + root: RpoDigest, + root_index: NodeIndex, + tree_depth: u8, + ) -> Result, MerkleError> { + // we set max depth at u64::BITS as this is the largest meaningful value for a 64-bit index + const MAX_DEPTH: u8 = u64::BITS as u8; + if tree_depth > MAX_DEPTH { + return Err(MerkleError::DepthTooBig(tree_depth as u64)); + } + let empty = EmptySubtreeRoots::empty_hashes(MAX_DEPTH); + + let mut node = root; + if !self.nodes.contains_key(&node) { + return Err(MerkleError::RootNotInStore(node)); } - // depleted bits; return max depth - Ok(tree_depth) + let mut index = root_index; + if index.depth() > tree_depth { + return Err(MerkleError::DepthTooBig(index.depth() as u64)); + } + + // traverse down following the path of single non-empty nodes; this works because if a + // node has two empty children it cannot contain a lone leaf. similarly if a node has + // two non-empty children it must contain at least two leaves. + for depth in index.depth()..tree_depth { + // if the node is a leaf, return; otherwise, examine the node's children + let children = match self.nodes.get(&node) { + Some(node) => node, + None => return Ok(Some((index, node))), + }; + + let empty_node = empty[depth as usize + 1]; + node = if children.left != empty_node && children.right == empty_node { + index = index.left_child(); + children.left + } else if children.left == empty_node && children.right != empty_node { + index = index.right_child(); + children.right + } else { + return Ok(None); + }; + } + + // if we are here, we got to `tree_depth`; thus, either the current node is a leaf node, + // and so we return it, or it is an internal node, and then we return an error + if self.nodes.contains_key(&node) { + Err(MerkleError::DepthTooBig(tree_depth as u64 + 1)) + } else { + Ok(Some((index, node))) + } } // DATA EXTRACTORS diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index dbc071e..daed892 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -637,6 +637,9 @@ fn node_path_should_be_truncated_by_midtier_insert() { assert!(store.get_node(root, index).is_err()); } +// LEAF TRAVERSAL +// ================================================================================================ + #[test] fn get_leaf_depth_works_depth_64() { let mut store = MerkleStore::new(); @@ -747,6 +750,67 @@ fn get_leaf_depth_works_with_depth_8() { assert_eq!(Err(MerkleError::DepthTooBig(9)), store.get_leaf_depth(root, 8, a)); } +#[test] +fn find_lone_leaf() { + let mut store = MerkleStore::new(); + let empty = EmptySubtreeRoots::empty_hashes(64); + let mut root: RpoDigest = empty[0]; + + // insert a single leaf into the store at depth 64 + let key_a = 0b01010101_10101010_00001111_01110100_00111011_10101101_00000100_01000001_u64; + let idx_a = NodeIndex::make(64, key_a); + let val_a = RpoDigest::from([ONE, ONE, ONE, ONE]); + root = store.set_node(root, idx_a, val_a).unwrap().root; + + // for every ancestor of A, A should be a long leaf + for depth in 1..64 { + let parent_index = NodeIndex::make(depth, key_a >> (64 - depth)); + let parent = store.get_node(root, parent_index).unwrap(); + + let res = store.find_lone_leaf(parent, parent_index, 64).unwrap(); + assert_eq!(res, Some((idx_a, val_a))); + } + + // insert another leaf into the store such that it has the same 8 bit prefix as A + let key_b = 0b01010101_01111010_00001111_01110100_00111011_10101101_00000100_01000001_u64; + let idx_b = NodeIndex::make(64, key_b); + let val_b = RpoDigest::from([ONE, ONE, ONE, ZERO]); + root = 
store.set_node(root, idx_b, val_b).unwrap().root; + + // for any node which is common between A and B, find_lone_leaf() should return None as the + // node has two descendants + for depth in 1..9 { + let parent_index = NodeIndex::make(depth, key_a >> (64 - depth)); + let parent = store.get_node(root, parent_index).unwrap(); + + let res = store.find_lone_leaf(parent, parent_index, 64).unwrap(); + assert_eq!(res, None); + } + + // for other ancestors of A and B, A and B should be lone leaves respectively + for depth in 9..64 { + let parent_index = NodeIndex::make(depth, key_a >> (64 - depth)); + let parent = store.get_node(root, parent_index).unwrap(); + + let res = store.find_lone_leaf(parent, parent_index, 64).unwrap(); + assert_eq!(res, Some((idx_a, val_a))); + } + + for depth in 9..64 { + let parent_index = NodeIndex::make(depth, key_b >> (64 - depth)); + let parent = store.get_node(root, parent_index).unwrap(); + + let res = store.find_lone_leaf(parent, parent_index, 64).unwrap(); + assert_eq!(res, Some((idx_b, val_b))); + } + + // for any other node, find_lone_leaf() should return None as they have no leaf nodes + let parent_index = NodeIndex::make(16, 0b01010101_11111111); + let parent = store.get_node(root, parent_index).unwrap(); + let res = store.find_lone_leaf(parent, parent_index, 64).unwrap(); + assert_eq!(res, None); +} + // SUBSET EXTRACTION // ================================================================================================ diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 2cb8792..cf2a26f 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -55,13 +55,13 @@ impl TieredSmt { // -------------------------------------------------------------------------------------------- /// The number of levels between tiers. - const TIER_SIZE: u8 = 16; + pub const TIER_SIZE: u8 = 16; /// Depths at which leaves can exist in a tiered SMT. - const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64]; + pub const TIER_DEPTHS: [u8; 4] = [16, 32, 48, 64]; /// Maximum node depth. This is also the bottom tier of the tree. - const MAX_DEPTH: u8 = 64; + pub const MAX_DEPTH: u8 = 64; /// Value of an empty leaf. 
pub const EMPTY_VALUE: Word = super::empty_roots::EMPTY_WORD; From 2f09410e87ca13f725b736832a91887100e786ce Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Thu, 31 Aug 2023 16:22:03 +0200 Subject: [PATCH 25/32] refactor: replace with EMPTY_WORD, ZERO and ONE --- src/hash/rpo/mds_freq.rs | 8 ++++---- src/hash/rpo/tests.rs | 10 +++++----- src/lib.rs | 3 +++ src/merkle/delta.rs | 4 ++-- src/merkle/empty_roots.rs | 10 ++-------- src/merkle/mod.rs | 2 +- src/merkle/partial_mt/mod.rs | 4 ++-- src/merkle/simple_smt/mod.rs | 2 +- src/merkle/simple_smt/tests.rs | 4 ++-- src/merkle/store/mod.rs | 6 +++--- src/merkle/store/tests.rs | 4 ++-- src/merkle/tiered_smt/mod.rs | 2 +- src/merkle/tiered_smt/tests.rs | 36 +++++++++++++++++----------------- 13 files changed, 46 insertions(+), 49 deletions(-) diff --git a/src/hash/rpo/mds_freq.rs b/src/hash/rpo/mds_freq.rs index ed4b449..6d1f1fd 100644 --- a/src/hash/rpo/mds_freq.rs +++ b/src/hash/rpo/mds_freq.rs @@ -156,14 +156,14 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { #[cfg(test)] mod tests { - use super::super::{Felt, FieldElement, Rpo256, MDS}; + use super::super::{Felt, Rpo256, MDS, ZERO}; use proptest::prelude::*; const STATE_WIDTH: usize = 12; #[inline(always)] fn apply_mds_naive(state: &mut [Felt; STATE_WIDTH]) { - let mut result = [Felt::ZERO; STATE_WIDTH]; + let mut result = [ZERO; STATE_WIDTH]; result.iter_mut().zip(MDS).for_each(|(r, mds_row)| { state.iter().zip(mds_row).for_each(|(&s, m)| { *r += m * s; @@ -174,9 +174,9 @@ mod tests { proptest! { #[test] - fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) { + fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) { - let mut v1 = [Felt::ZERO;STATE_WIDTH]; + let mut v1 = [ZERO; STATE_WIDTH]; let mut v2; for i in 0..STATE_WIDTH { diff --git a/src/hash/rpo/tests.rs b/src/hash/rpo/tests.rs index d0f6889..3ca5a33 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rpo/tests.rs @@ -105,7 +105,7 @@ fn hash_elements_vs_merge_with_int() { let mut elements = seed.as_elements().to_vec(); elements.push(Felt::new(val)); - elements.push(Felt::new(1)); + elements.push(ONE); let h_result = Rpo256::hash_elements(&elements); assert_eq!(m_result, h_result); @@ -147,8 +147,8 @@ fn hash_elements_padding() { #[test] fn hash_elements() { let elements = [ - Felt::new(0), - Felt::new(1), + ZERO, + ONE, Felt::new(2), Felt::new(3), Felt::new(4), @@ -170,8 +170,8 @@ fn hash_elements() { #[test] fn hash_test_vectors() { let elements = [ - Felt::new(0), - Felt::new(1), + ZERO, + ONE, Felt::new(2), Felt::new(3), Felt::new(4), diff --git a/src/lib.rs b/src/lib.rs index 7c7d753..cb0e11e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,9 @@ pub const ZERO: Felt = Felt::ZERO; /// Field element representing ONE in the Miden base filed. pub const ONE: Felt = Felt::ONE; +/// Array of field elements representing word of ZEROs in the Miden base field. 
+pub const EMPTY_WORD: [Felt; 4] = [ZERO; WORD_SIZE]; + // TESTS // ================================================================================================ diff --git a/src/merkle/delta.rs b/src/merkle/delta.rs index cf6d1b9..064cd01 100644 --- a/src/merkle/delta.rs +++ b/src/merkle/delta.rs @@ -4,7 +4,7 @@ use super::{ use crate::utils::collections::Diff; #[cfg(test)] -use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt}; +use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO}; // MERKLE STORE DELTA // ================================================================================================ @@ -121,7 +121,7 @@ pub struct MerkleTreeDelta { #[test] fn test_compute_merkle_delta() { let entries = vec![ - (10, [Felt::new(0), Felt::new(1), Felt::new(2), Felt::new(3)]), + (10, [ZERO, ONE, Felt::new(2), Felt::new(3)]), (15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]), (20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]), (31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]), diff --git a/src/merkle/empty_roots.rs b/src/merkle/empty_roots.rs index b1b0b30..17cd781 100644 --- a/src/merkle/empty_roots.rs +++ b/src/merkle/empty_roots.rs @@ -1,12 +1,6 @@ -use super::{Felt, RpoDigest, Word, WORD_SIZE, ZERO}; +use super::{Felt, RpoDigest, EMPTY_WORD}; use core::slice; -// CONSTANTS -// ================================================================================================ - -/// A word consisting of 4 ZERO elements. -pub const EMPTY_WORD: Word = [ZERO; WORD_SIZE]; - // EMPTY NODES SUBTREES // ================================================================================================ @@ -1556,7 +1550,7 @@ const EMPTY_SUBTREES: [RpoDigest; 256] = [ Felt::new(0xd3ad9fb0cea61624), Felt::new(0x66ab5c684fbb8597), ]), - RpoDigest::new([ZERO; WORD_SIZE]), + RpoDigest::new(EMPTY_WORD), ]; #[test] diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 4d9bbeb..c4e43ca 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -1,7 +1,7 @@ use super::{ hash::rpo::{Rpo256, RpoDigest}, utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec}, - Felt, StarkField, Word, WORD_SIZE, ZERO, + Felt, StarkField, Word, EMPTY_WORD, ZERO, }; // REEXPORTS diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 1231a04..2c95972 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -1,6 +1,6 @@ use super::{ BTreeMap, BTreeSet, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, - ValuePath, Vec, Word, ZERO, + ValuePath, Vec, Word, EMPTY_WORD, }; use crate::utils::{ format, string::String, vec, word_to_hex, ByteReader, ByteWriter, Deserializable, @@ -18,7 +18,7 @@ mod tests; const ROOT_INDEX: NodeIndex = NodeIndex::root(); /// An RpoDigest consisting of 4 ZERO elements. -const EMPTY_DIGEST: RpoDigest = RpoDigest::new([ZERO; 4]); +const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD); // PARTIAL MERKLE TREE // ================================================================================================ diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index d9e80c1..7c40d90 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -33,7 +33,7 @@ impl SimpleSmt { pub const MAX_DEPTH: u8 = 64; /// Value of an empty leaf. 
- pub const EMPTY_VALUE: Word = super::empty_roots::EMPTY_WORD; + pub const EMPTY_VALUE: Word = super::EMPTY_WORD; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index d86f98b..26aaab2 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -1,9 +1,9 @@ use super::{ - super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt}, + super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD}, NodeIndex, Rpo256, Vec, }; use crate::{ - merkle::{digests_to_words, empty_roots::EMPTY_WORD, int_to_leaf, int_to_node}, + merkle::{digests_to_words, int_to_leaf, int_to_node}, Word, }; diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 5479b0f..39c7548 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -1,7 +1,7 @@ use super::{ - empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, - MerkleError, MerklePath, MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, - RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, + mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath, + MerkleStoreDelta, MerkleTree, NodeIndex, PartialMerkleTree, RecordingMap, RootPath, Rpo256, + RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec, EMPTY_WORD, }; use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use core::borrow::Borrow; diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index daed892..dc32ffd 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -478,7 +478,7 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> { #[test] fn wont_open_to_different_depth_root() { let empty = EmptySubtreeRoots::empty_hashes(64); - let a = [Felt::new(1); 4]; + let a = [ONE; 4]; let b = [Felt::new(2); 4]; // Compute the root for a different depth. We cherry-pick this specific depth to prevent a @@ -501,7 +501,7 @@ fn wont_open_to_different_depth_root() { #[test] fn store_path_opens_from_leaf() { - let a = [Felt::new(1); 4]; + let a = [ONE; 4]; let b = [Felt::new(2); 4]; let c = [Felt::new(3); 4]; let d = [Felt::new(4); 4]; diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs index 8dc0c7a..7fbb2ca 100644 --- a/src/merkle/tiered_smt/mod.rs +++ b/src/merkle/tiered_smt/mod.rs @@ -64,7 +64,7 @@ impl TieredSmt { pub const MAX_DEPTH: u8 = 64; /// Value of an empty leaf. 
- pub const EMPTY_VALUE: Word = super::empty_roots::EMPTY_WORD; + pub const EMPTY_VALUE: Word = super::EMPTY_WORD; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index 61e6081..560db47 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -1,5 +1,5 @@ use super::{ - super::{super::ONE, empty_roots::EMPTY_WORD, Felt, MerkleStore, WORD_SIZE, ZERO}, + super::{super::ONE, super::WORD_SIZE, Felt, MerkleStore, EMPTY_WORD, ZERO}, EmptySubtreeRoots, InnerNodeInfo, NodeIndex, Rpo256, RpoDigest, TieredSmt, Vec, Word, }; @@ -279,11 +279,11 @@ fn tsmt_delete_16() { smt.insert(key_b, value_b); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b); assert_eq!(smt, smt1); // --- delete the first inserted value ------------------------------------ - assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a); assert_eq!(smt, smt0); } @@ -313,15 +313,15 @@ fn tsmt_delete_32() { smt.insert(key_c, value_c); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c); assert_eq!(smt, smt2); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b); assert_eq!(smt, smt1); // --- delete the first inserted value ------------------------------------ - assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a); assert_eq!(smt, smt0); } @@ -353,15 +353,15 @@ fn tsmt_delete_48_same_32_bit_prefix() { smt.insert(key_c, value_c); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c); assert_eq!(smt, smt2); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b); assert_eq!(smt, smt1); // --- delete the first inserted value ------------------------------------ - assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a); assert_eq!(smt, smt0); } @@ -400,16 +400,16 @@ fn tsmt_delete_48_mixed_prefix() { smt.insert(key_d, value_d); // --- delete the inserted values one-by-one ------------------------------ - assert_eq!(smt.insert(key_d, [ZERO; 4]), value_d); + assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d); assert_eq!(smt, smt3); - assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c); assert_eq!(smt, smt2); - assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b); assert_eq!(smt, smt1); - assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a); assert_eq!(smt, smt0); } @@ -447,16 +447,16 @@ fn tsmt_delete_64() { smt.insert(key_d, value_d); // --- delete the last inserted value ------------------------------------- - assert_eq!(smt.insert(key_d, [ZERO; 4]), value_d); + assert_eq!(smt.insert(key_d, EMPTY_WORD), value_d); assert_eq!(smt, smt3); - 
assert_eq!(smt.insert(key_c, [ZERO; 4]), value_c); + assert_eq!(smt.insert(key_c, EMPTY_WORD), value_c); assert_eq!(smt, smt2); - assert_eq!(smt.insert(key_b, [ZERO; 4]), value_b); + assert_eq!(smt.insert(key_b, EMPTY_WORD), value_b); assert_eq!(smt, smt1); - assert_eq!(smt.insert(key_a, [ZERO; 4]), value_a); + assert_eq!(smt.insert(key_a, EMPTY_WORD), value_a); assert_eq!(smt, smt0); } @@ -577,7 +577,7 @@ fn test_order_sensitivity() { smt_1.insert(key_1, value); smt_1.insert(key_2, value); - smt_1.insert(key_2, [ZERO; WORD_SIZE]); + smt_1.insert(key_2, EMPTY_WORD); let mut smt_2 = TieredSmt::default(); smt_2.insert(key_1, value); From 1fa28957245e38ca798072e00586340af8f21a28 Mon Sep 17 00:00:00 2001 From: frisitano Date: Tue, 19 Sep 2023 16:01:52 +0800 Subject: [PATCH 26/32] refactor: modify MerkleStore::non_empty_leaves to support TSMT --- src/merkle/store/mod.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index 39c7548..a8dd1fb 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -346,10 +346,13 @@ impl> MerkleStore { core::iter::from_fn(move || { while let Some((index, node_hash)) = stack.pop() { + // if we are at the max depth then we have reached a leaf if index.depth() == max_depth { return Some((index, node_hash)); } + // fetch the nodes children and push them onto the stack if they are not the roots + // of empty subtrees if let Some(node) = self.nodes.get(&node_hash) { if !empty_roots.contains(&node.left) { stack.push((index.left_child(), node.left)); @@ -357,6 +360,13 @@ impl> MerkleStore { if !empty_roots.contains(&node.right) { stack.push((index.right_child(), node.right)); } + + // if the node is not in the store assume it is a leaf + } else { + // assert that if we have a leaf that is not at the max depth then it must be + // at the depth of one of the tiers of an TSMT. 
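+                    // (leaves at the maximum depth were already returned by the `max_depth`
+                    // check above, which is why only the first three tier depths are checked here)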
+ debug_assert!(TieredSmt::TIER_DEPTHS[..3].contains(&index.depth())); + return Some((index, node_hash)); } } From 701a187e7f7db3a62d3ef82af34504e0b50d20ca Mon Sep 17 00:00:00 2001 From: Grzegorz Swirski Date: Fri, 15 Sep 2023 11:09:03 +0200 Subject: [PATCH 27/32] feat: implement RPO hash using SVE instructionss --- .gitignore | 3 + Cargo.toml | 4 + arch/arm64-sve/CMakeLists.txt | 10 ++ arch/arm64-sve/library.c | 78 ++++++++++++ arch/arm64-sve/library.h | 12 ++ arch/arm64-sve/rpo_hash.h | 221 ++++++++++++++++++++++++++++++++++ arch/arm64-sve/test.c | 27 +++++ build.rs | 17 +++ src/hash/rpo/mod.rs | 68 ++++++++++- 9 files changed, 436 insertions(+), 4 deletions(-) create mode 100644 arch/arm64-sve/CMakeLists.txt create mode 100644 arch/arm64-sve/library.c create mode 100644 arch/arm64-sve/library.h create mode 100644 arch/arm64-sve/rpo_hash.h create mode 100644 arch/arm64-sve/test.c create mode 100644 build.rs diff --git a/.gitignore b/.gitignore index 088ba6b..1f17879 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Generated by cmake +cmake-build-* diff --git a/Cargo.toml b/Cargo.toml index 04dc990..4c92e57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ name = "store" harness = false [features] +arch-arm64-sve = ["dep:cc"] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std", "rand_utils"] serde = ["winter_math/serde", "dep:serde", "serde/alloc"] @@ -49,3 +50,6 @@ rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true } criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.1.0" rand_utils = { version = "0.6", package = "winter-rand-utils" } + +[build-dependencies] +cc = { version = "1.0.79", optional = true } diff --git a/arch/arm64-sve/CMakeLists.txt b/arch/arm64-sve/CMakeLists.txt new file mode 100644 index 0000000..40710b1 --- /dev/null +++ b/arch/arm64-sve/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.0) +project(rpo_sve C) + +set(CMAKE_C_STANDARD 23) +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve -Wall -Wextra -pedantic -g -O3") + +add_library(rpo_sve library.c rpo_hash.h) + +add_executable(rpo_test test.c) +target_link_libraries(rpo_test rpo_sve) diff --git a/arch/arm64-sve/library.c b/arch/arm64-sve/library.c new file mode 100644 index 0000000..a1791f7 --- /dev/null +++ b/arch/arm64-sve/library.c @@ -0,0 +1,78 @@ +#include +#include +#include "library.h" +#include "rpo_hash.h" + +// The STATE_WIDTH of RPO hash is 12x u64 elements. +// The current generation of SVE-enabled processors - Neoverse V1 +// (e.g. AWS Graviton3) have 256-bit vector registers (4x u64) +// This allows us to split the state into 3 vectors of 4 elements +// and process all 3 independent of each other. + +// We see the biggest performance gains by leveraging both +// vector and scalar operations on parts of the state array. +// Due to high latency of vector operations, the processor is able +// to reorder and pipeline scalar instructions while we wait for +// vector results. This effectively gives us some 'free' scalar +// operations and masks vector latency. +// +// This also means that we can fully saturate all four arithmetic +// units of the processor (2x scalar, 2x SIMD) +// +// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS +// GAIN WIDER REGISTERS. 
It's quite possible that with 8x u64 +// vectors processing 2 partially filled vectors might +// be easier and faster than dealing with scalar operations +// on the remainder of the array. +// +// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back +// to the regular, already highly-optimized scalar version +// if the conditions are not met. + +bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { + const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector + + if (vl != 4) { + return false; + } + + svbool_t ptrue = svptrue_b64(); + + svuint64_t state1 = svld1(ptrue, state + 0*vl); + svuint64_t state2 = svld1(ptrue, state + 1*vl); + + svuint64_t const1 = svld1(ptrue, constants + 0*vl); + svuint64_t const2 = svld1(ptrue, constants + 1*vl); + + add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8); + apply_sbox(ptrue, &state1, &state2, state+8); + + svst1(ptrue, state + 0*vl, state1); + svst1(ptrue, state + 1*vl, state2); + + return true; +} + +bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) { + const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector + + if (vl != 4) { + return false; + } + + svbool_t ptrue = svptrue_b64(); + + svuint64_t state1 = svld1(ptrue, state + 0 * vl); + svuint64_t state2 = svld1(ptrue, state + 1 * vl); + + svuint64_t const1 = svld1(ptrue, constants + 0 * vl); + svuint64_t const2 = svld1(ptrue, constants + 1 * vl); + + add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8); + apply_inv_sbox(ptrue, &state1, &state2, state + 8); + + svst1(ptrue, state + 0 * vl, state1); + svst1(ptrue, state + 1 * vl, state2); + + return true; +} diff --git a/arch/arm64-sve/library.h b/arch/arm64-sve/library.h new file mode 100644 index 0000000..c8f1cdd --- /dev/null +++ b/arch/arm64-sve/library.h @@ -0,0 +1,12 @@ +#ifndef CRYPTO_LIBRARY_H +#define CRYPTO_LIBRARY_H + +#include +#include + +#define STATE_WIDTH 12 + +bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); +bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]); + +#endif //CRYPTO_LIBRARY_H diff --git a/arch/arm64-sve/rpo_hash.h b/arch/arm64-sve/rpo_hash.h new file mode 100644 index 0000000..567298f --- /dev/null +++ b/arch/arm64-sve/rpo_hash.h @@ -0,0 +1,221 @@ +#ifndef RPO_SVE_RPO_HASH_H +#define RPO_SVE_RPO_HASH_H + +#include +#include +#include +#include + +#define COPY(NAME, VIN1, VIN2, SIN3) \ + svuint64_t NAME ## _1 = VIN1; \ + svuint64_t NAME ## _2 = VIN2; \ + uint64_t NAME ## _3[4]; \ + memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t)) + +#define MULTIPLY(PRED, DEST, OP) \ + mul(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3) + +#define SQUARE(PRED, NAME) \ + sq(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3) + +#define SQUARE_DEST(PRED, DEST, SRC) \ + COPY(DEST, SRC ## _1, SRC ## _2, SRC ## _3); \ + SQUARE(PRED, DEST); + +#define POW_ACC(PRED, NAME, CNT, TAIL) \ + for (size_t i = 0; i < CNT; i++) { \ + SQUARE(PRED, NAME); \ + } \ + MULTIPLY(PRED, NAME, TAIL); + +#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \ + COPY(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \ + POW_ACC(PRED, DEST, CNT, TAIL) + +extern inline void add_constants( + svbool_t pg, + svuint64_t *state1, + svuint64_t *const1, + svuint64_t *state2, + svuint64_t *const2, + uint64_t *state3, + uint64_t *const3 +) { + uint64_t Ms = 0xFFFFFFFF00000001ull; + 
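+    // Ms is the Miden base field modulus p = 2^64 - 2^32 + 1; svindex_u64(Ms, 0) on the next
+    // line broadcasts it into every 64-bit lane of an SVE vector.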
svuint64_t Mv = svindex_u64(Ms, 0); + + uint64_t p_1 = Ms - const3[0]; + uint64_t p_2 = Ms - const3[1]; + uint64_t p_3 = Ms - const3[2]; + uint64_t p_4 = Ms - const3[3]; + + uint64_t x_1, x_2, x_3, x_4; + uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1); + uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2); + uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3); + uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4); + + state3[0] = x_1 - (uint64_t)adj_1; + state3[1] = x_2 - (uint64_t)adj_2; + state3[2] = x_3 - (uint64_t)adj_3; + state3[3] = x_4 - (uint64_t)adj_4; + + svuint64_t p1 = svsub_x(pg, Mv, *const1); + svuint64_t p2 = svsub_x(pg, Mv, *const2); + + svuint64_t x1 = svsub_x(pg, *state1, p1); + svuint64_t x2 = svsub_x(pg, *state2, p2); + + svbool_t pt1 = svcmplt_u64(pg, *state1, p1); + svbool_t pt2 = svcmplt_u64(pg, *state2, p2); + + *state1 = svsub_m(pt1, x1, (uint32_t)-1); + *state2 = svsub_m(pt2, x2, (uint32_t)-1); +} + +extern inline void mul( + svbool_t pg, + svuint64_t *r1, + const svuint64_t *op1, + svuint64_t *r2, + const svuint64_t *op2, + uint64_t *r3, + const uint64_t *op3 +) { + __uint128_t x_1 = r3[0]; + __uint128_t x_2 = r3[1]; + __uint128_t x_3 = r3[2]; + __uint128_t x_4 = r3[3]; + + x_1 *= (__uint128_t) op3[0]; + x_2 *= (__uint128_t) op3[1]; + x_3 *= (__uint128_t) op3[2]; + x_4 *= (__uint128_t) op3[3]; + + uint64_t x0_1 = x_1; + uint64_t x0_2 = x_2; + uint64_t x0_3 = x_3; + uint64_t x0_4 = x_4; + + svuint64_t l1 = svmul_x(pg, *r1, *op1); + svuint64_t l2 = svmul_x(pg, *r2, *op2); + + uint64_t x1_1 = (x_1 >> 64); + uint64_t x1_2 = (x_2 >> 64); + uint64_t x1_3 = (x_3 >> 64); + uint64_t x1_4 = (x_4 >> 64); + + uint64_t a_1, a_2, a_3, a_4; + uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1); + uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2); + uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3); + uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4); + + svuint64_t ls1 = svlsl_x(pg, l1, 32); + svuint64_t ls2 = svlsl_x(pg, l2, 32); + + svuint64_t a1 = svadd_x(pg, l1, ls1); + svuint64_t a2 = svadd_x(pg, l2, ls2); + + svbool_t e1 = svcmplt(pg, a1, l1); + svbool_t e2 = svcmplt(pg, a2, l2); + + svuint64_t as1 = svlsr_x(pg, a1, 32); + svuint64_t as2 = svlsr_x(pg, a2, 32); + + svuint64_t b1 = svsub_x(pg, a1, as1); + svuint64_t b2 = svsub_x(pg, a2, as2); + + b1 = svsub_m(e1, b1, 1); + b2 = svsub_m(e2, b2, 1); + + uint64_t b_1 = a_1 - (a_1 >> 32) - e_1; + uint64_t b_2 = a_2 - (a_2 >> 32) - e_2; + uint64_t b_3 = a_3 - (a_3 >> 32) - e_3; + uint64_t b_4 = a_4 - (a_4 >> 32) - e_4; + + uint64_t r_1, r_2, r_3, r_4; + uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1); + uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2); + uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3); + uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4); + + svuint64_t h1 = svmulh_x(pg, *r1, *op1); + svuint64_t h2 = svmulh_x(pg, *r2, *op2); + + svuint64_t tr1 = svsub_x(pg, h1, b1); + svuint64_t tr2 = svsub_x(pg, h2, b2); + + svbool_t c1 = svcmplt_u64(pg, h1, b1); + svbool_t c2 = svcmplt_u64(pg, h2, b2); + + *r1 = svsub_m(c1, tr1, (uint32_t) -1); + *r2 = svsub_m(c2, tr2, (uint32_t) -1); + + uint32_t minus1_1 = 0 - c_1; + uint32_t minus1_2 = 0 - c_2; + uint32_t minus1_3 = 0 - c_3; + uint32_t minus1_4 = 0 - c_4; + + r3[0] = r_1 - (uint64_t)minus1_1; + r3[1] = r_2 - (uint64_t)minus1_2; + r3[2] = r_3 - (uint64_t)minus1_3; + r3[3] = r_4 - (uint64_t)minus1_4; +} + +extern inline void sq(svbool_t pg, svuint64_t 
*a, svuint64_t *b, uint64_t *c) { + mul(pg, a, a, b, b, c, c); +} + +extern inline void apply_sbox( + svbool_t pg, + svuint64_t *state1, + svuint64_t *state2, + uint64_t *state3 +) { + COPY(x, *state1, *state2, state3); // copy input to x + SQUARE(pg, x); // x contains input^2 + mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^3 + SQUARE(pg, x); // x contains input^4 + mul(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^7 +} + +extern inline void apply_inv_sbox( + svbool_t pg, + svuint64_t *state_1, + svuint64_t *state_2, + uint64_t *state_3 +) { + // base^10 + COPY(t1, *state_1, *state_2, state_3); + SQUARE(pg, t1); + + // base^100 + SQUARE_DEST(pg, t2, t1); + + // base^100100 + POW_ACC_DEST(pg, t3, 3, t2, t2); + + // base^100100100100 + POW_ACC_DEST(pg, t4, 6, t3, t3); + + // compute base^100100100100100100100100 + POW_ACC_DEST(pg, t5, 12, t4, t4); + + // compute base^100100100100100100100100100100 + POW_ACC_DEST(pg, t6, 6, t5, t3); + + // compute base^1001001001001001001001001001000100100100100100100100100100100 + POW_ACC_DEST(pg, t7, 31, t6, t6); + + // compute base^1001001001001001001001001001000110110110110110110110110110110111 + SQUARE(pg, t7); + MULTIPLY(pg, t7, t6); + SQUARE(pg, t7); + SQUARE(pg, t7); + MULTIPLY(pg, t7, t1); + MULTIPLY(pg, t7, t2); + mul(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3); +} + +#endif //RPO_SVE_RPO_HASH_H diff --git a/arch/arm64-sve/test.c b/arch/arm64-sve/test.c new file mode 100644 index 0000000..78e2f50 --- /dev/null +++ b/arch/arm64-sve/test.c @@ -0,0 +1,27 @@ +#include +#include "library.h" + +void print_array(size_t len, uint64_t arr[len]); + +int main() { + uint64_t C[STATE_WIDTH] = {1, 1, 1, 1 ,1, 1, 1, 1 ,1, 1, 1, 1}; + uint64_t T[STATE_WIDTH] = {1, 2, 3, 4, 1, 2, 3, 4,1, 2, 3, 4}; + + add_constants_and_apply_sbox(T, C); + add_constants_and_apply_inv_sbox(T, C); + + print_array(STATE_WIDTH, T); + + return 0; +} + +void print_array(size_t len, uint64_t arr[len]) +{ + printf("["); + for (size_t i = 0; i < len; i++) + { + printf("%lu ", arr[i]); + } + + printf("]\n"); +} diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..7d95857 --- /dev/null +++ b/build.rs @@ -0,0 +1,17 @@ +fn main() { + #[cfg(feature = "arch-arm64-sve")] + compile_arch_arm64_sve(); +} + +#[cfg(feature = "arch-arm64-sve")] +fn compile_arch_arm64_sve() { + println!("cargo:rerun-if-changed=arch/arm64-sve/library.c"); + println!("cargo:rerun-if-changed=arch/arm64-sve/library.h"); + println!("cargo:rerun-if-changed=arch/arm64-sve/rpo_hash.h"); + + cc::Build::new() + .file("arch/arm64-sve/library.c") + .flag("-march=armv8-a+sve") + .flag("-O3") + .compile("rpo_sve"); +} diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index 95f2c97..dc7df3f 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -10,6 +10,19 @@ use mds_freq::mds_multiply_freq; #[cfg(test)] mod tests; +#[cfg(feature = "arch-arm64-sve")] +#[link(name = "rpo_sve", kind = "static")] +extern "C" { + fn add_constants_and_apply_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; + fn add_constants_and_apply_inv_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; +} + // CONSTANTS // ================================================================================================ @@ -345,18 +358,65 @@ impl Rpo256 { pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) { // apply first half of RPO round Self::apply_mds(state); - Self::add_constants(state, 
&ARK1[round]); - Self::apply_sbox(state); + if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { + Self::add_constants(state, &ARK1[round]); + Self::apply_sbox(state); + } // apply second half of RPO round Self::apply_mds(state); - Self::add_constants(state, &ARK2[round]); - Self::apply_inv_sbox(state); + if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { + Self::add_constants(state, &ARK2[round]); + Self::apply_inv_sbox(state); + } } // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- + #[inline(always)] + #[cfg(feature = "arch-arm64-sve")] + fn optimized_add_constants_and_apply_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + unsafe { + add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) + } + } + + #[inline(always)] + #[cfg(not(feature = "arch-arm64-sve"))] + fn optimized_add_constants_and_apply_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], + ) -> bool { + false + } + + #[inline(always)] + #[cfg(feature = "arch-arm64-sve")] + fn optimized_add_constants_and_apply_inv_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], + ) -> bool { + unsafe { + add_constants_and_apply_inv_sbox( + state.as_mut_ptr() as *mut u64, + ark.as_ptr() as *const u64, + ) + } + } + + #[inline(always)] + #[cfg(not(feature = "arch-arm64-sve"))] + fn optimized_add_constants_and_apply_inv_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], + ) -> bool { + false + } + #[inline(always)] fn apply_mds(state: &mut [Felt; STATE_WIDTH]) { let mut result = [ZERO; STATE_WIDTH]; From 01be4d6b9d0b0bbb50623b96e26c18f6030e0368 Mon Sep 17 00:00:00 2001 From: Grzegorz Swirski Date: Sun, 24 Sep 2023 22:23:38 +0200 Subject: [PATCH 28/32] refactor: move arch specific code to rpo folder, don't run SVE on CI --- .github/workflows/ci.yml | 4 ++-- .pre-commit-config.yaml | 4 ++-- arch/arm64-sve/CMakeLists.txt | 10 ---------- arch/arm64-sve/{ => rpo}/library.c | 0 arch/arm64-sve/{ => rpo}/library.h | 0 arch/arm64-sve/{ => rpo}/rpo_hash.h | 0 arch/arm64-sve/test.c | 27 --------------------------- build.rs | 8 ++++---- 8 files changed, 8 insertions(+), 45 deletions(-) delete mode 100644 arch/arm64-sve/CMakeLists.txt rename arch/arm64-sve/{ => rpo}/library.c (100%) rename arch/arm64-sve/{ => rpo}/library.h (100%) rename arch/arm64-sve/{ => rpo}/rpo_hash.h (100%) delete mode 100644 arch/arm64-sve/test.c diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6ee9fb3..2ff1609 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: matrix: toolchain: [stable, nightly] os: [ubuntu] - features: [--all-features, --no-default-features] + features: ["--features default,std,serde", --no-default-features] steps: - uses: actions/checkout@main - name: Install rust @@ -59,7 +59,7 @@ jobs: strategy: fail-fast: false matrix: - features: [--all-features, --no-default-features] + features: ["--features default,std,serde", --no-default-features] steps: - uses: actions/checkout@main - name: Install minimal nightly with clippy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d00cf26..4f909e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,8 +35,8 @@ repos: name: Cargo check --all-targets --no-default-features args: ["+stable", "check", "--all-targets", "--no-default-features"] - id: cargo - name: Cargo check --all-targets 
--all-features - args: ["+stable", "check", "--all-targets", "--all-features"] + name: Cargo check --all-targets --features default,std,serde + args: ["+stable", "check", "--all-targets", "--features", "default,std,serde"] # Unlike fmt, clippy will not be automatically applied - id: cargo name: Cargo clippy diff --git a/arch/arm64-sve/CMakeLists.txt b/arch/arm64-sve/CMakeLists.txt deleted file mode 100644 index 40710b1..0000000 --- a/arch/arm64-sve/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -cmake_minimum_required(VERSION 3.0) -project(rpo_sve C) - -set(CMAKE_C_STANDARD 23) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve -Wall -Wextra -pedantic -g -O3") - -add_library(rpo_sve library.c rpo_hash.h) - -add_executable(rpo_test test.c) -target_link_libraries(rpo_test rpo_sve) diff --git a/arch/arm64-sve/library.c b/arch/arm64-sve/rpo/library.c similarity index 100% rename from arch/arm64-sve/library.c rename to arch/arm64-sve/rpo/library.c diff --git a/arch/arm64-sve/library.h b/arch/arm64-sve/rpo/library.h similarity index 100% rename from arch/arm64-sve/library.h rename to arch/arm64-sve/rpo/library.h diff --git a/arch/arm64-sve/rpo_hash.h b/arch/arm64-sve/rpo/rpo_hash.h similarity index 100% rename from arch/arm64-sve/rpo_hash.h rename to arch/arm64-sve/rpo/rpo_hash.h diff --git a/arch/arm64-sve/test.c b/arch/arm64-sve/test.c deleted file mode 100644 index 78e2f50..0000000 --- a/arch/arm64-sve/test.c +++ /dev/null @@ -1,27 +0,0 @@ -#include -#include "library.h" - -void print_array(size_t len, uint64_t arr[len]); - -int main() { - uint64_t C[STATE_WIDTH] = {1, 1, 1, 1 ,1, 1, 1, 1 ,1, 1, 1, 1}; - uint64_t T[STATE_WIDTH] = {1, 2, 3, 4, 1, 2, 3, 4,1, 2, 3, 4}; - - add_constants_and_apply_sbox(T, C); - add_constants_and_apply_inv_sbox(T, C); - - print_array(STATE_WIDTH, T); - - return 0; -} - -void print_array(size_t len, uint64_t arr[len]) -{ - printf("["); - for (size_t i = 0; i < len; i++) - { - printf("%lu ", arr[i]); - } - - printf("]\n"); -} diff --git a/build.rs b/build.rs index 7d95857..f65f075 100644 --- a/build.rs +++ b/build.rs @@ -5,12 +5,12 @@ fn main() { #[cfg(feature = "arch-arm64-sve")] fn compile_arch_arm64_sve() { - println!("cargo:rerun-if-changed=arch/arm64-sve/library.c"); - println!("cargo:rerun-if-changed=arch/arm64-sve/library.h"); - println!("cargo:rerun-if-changed=arch/arm64-sve/rpo_hash.h"); + println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/library.c"); + println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/library.h"); + println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/rpo_hash.h"); cc::Build::new() - .file("arch/arm64-sve/library.c") + .file("arch/arm64-sve/rpo/library.c") .flag("-march=armv8-a+sve") .flag("-O3") .compile("rpo_sve"); From 8078021aff0717eaa4add176bf6ac8ea5077a348 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:45:18 +0200 Subject: [PATCH 29/32] feat: Falcon 512 signature --- .github/workflows/ci.yml | 6 + .gitmodules | 3 + Cargo.toml | 12 +- PQClean | 1 + build.rs | 24 + rustfmt.toml | 1 + src/dsa/mod.rs | 1 + src/dsa/rpo_falcon512/error.rs | 55 ++ src/dsa/rpo_falcon512/falcon_c/api_rpo.h | 66 +++ src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c | 387 +++++++++++++ src/dsa/rpo_falcon512/falcon_c/rpo.c | 582 ++++++++++++++++++++ src/dsa/rpo_falcon512/falcon_c/rpo.h | 83 +++ src/dsa/rpo_falcon512/ffi.rs | 189 +++++++ src/dsa/rpo_falcon512/keys.rs | 227 ++++++++ src/dsa/rpo_falcon512/mod.rs | 60 ++ src/dsa/rpo_falcon512/polynomial.rs | 277 ++++++++++ 
src/dsa/rpo_falcon512/signature.rs | 262 +++++++++ src/lib.rs | 1 + src/merkle/merkle_tree.rs | 18 +- src/merkle/mmr/full.rs | 10 +- src/merkle/mmr/tests.rs | 6 +- src/merkle/partial_mt/mod.rs | 6 +- src/merkle/path.rs | 11 +- src/merkle/simple_smt/mod.rs | 5 +- src/merkle/simple_smt/tests.rs | 18 +- src/merkle/store/mod.rs | 56 +- src/merkle/tiered_smt/values.rs | 5 +- src/utils/mod.rs | 5 +- 28 files changed, 2263 insertions(+), 114 deletions(-) create mode 100644 .gitmodules create mode 160000 PQClean create mode 100644 src/dsa/mod.rs create mode 100644 src/dsa/rpo_falcon512/error.rs create mode 100644 src/dsa/rpo_falcon512/falcon_c/api_rpo.h create mode 100644 src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c create mode 100644 src/dsa/rpo_falcon512/falcon_c/rpo.c create mode 100644 src/dsa/rpo_falcon512/falcon_c/rpo.h create mode 100644 src/dsa/rpo_falcon512/ffi.rs create mode 100644 src/dsa/rpo_falcon512/keys.rs create mode 100644 src/dsa/rpo_falcon512/mod.rs create mode 100644 src/dsa/rpo_falcon512/polynomial.rs create mode 100644 src/dsa/rpo_falcon512/signature.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ff1609..d89fbb8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,6 +19,8 @@ jobs: args: [--no-default-features --target wasm32-unknown-unknown] steps: - uses: actions/checkout@main + with: + submodules: recursive - name: Install rust uses: actions-rs/toolchain@v1 with: @@ -42,6 +44,8 @@ jobs: features: ["--features default,std,serde", --no-default-features] steps: - uses: actions/checkout@main + with: + submodules: recursive - name: Install rust uses: actions-rs/toolchain@v1 with: @@ -62,6 +66,8 @@ jobs: features: ["--features default,std,serde", --no-default-features] steps: - uses: actions/checkout@main + with: + submodules: recursive - name: Install minimal nightly with clippy uses: actions-rs/toolchain@v1 with: diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..88ae99f --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "PQClean"] + path = PQClean + url = https://github.com/PQClean/PQClean.git diff --git a/Cargo.toml b/Cargo.toml index 4c92e57..c77a1d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ name = "miden-crypto" path = "src/main.rs" bench = false doctest = false -required-features = ["std"] +required-features = ["executable"] [[bench]] name = "hash" @@ -34,12 +34,15 @@ harness = false [features] arch-arm64-sve = ["dep:cc"] default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] -std = ["blake3/std", "winter_crypto/std", "winter_math/std", "winter_utils/std", "rand_utils"] +executable = ["dep:clap", "dep:rand_utils", "std"] +std = ["blake3/std", "dep:cc", "dep:libc", "dep:rand", "winter_crypto/std", "winter_math/std", "winter_utils/std"] serde = ["winter_math/serde", "dep:serde", "serde/alloc"] [dependencies] blake3 = { version = "1.4", default-features = false } -clap = { version = "4.3.21", features = ["derive"] } +clap = { version = "4.3", features = ["derive"], optional = true} +libc = { version = "0.2", optional = true, default-features = false } +rand = { version = "0.8", optional = true, default-features = false } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } @@ -52,4 +55,5 @@ proptest = "1.1.0" rand_utils = 
{ version = "0.6", package = "winter-rand-utils" } [build-dependencies] -cc = { version = "1.0.79", optional = true } +cc = { version = "1.0", features = ["parallel"], optional = true } +glob = "*" diff --git a/PQClean b/PQClean new file mode 160000 index 0000000..c3abebf --- /dev/null +++ b/PQClean @@ -0,0 +1 @@ +Subproject commit c3abebf4ab1ff516ffa71e6337f06d898952c299 diff --git a/build.rs b/build.rs index f65f075..e27f9df 100644 --- a/build.rs +++ b/build.rs @@ -1,8 +1,32 @@ fn main() { + #[cfg(feature = "std")] + compile_rpo_falcon(); + #[cfg(feature = "arch-arm64-sve")] compile_arch_arm64_sve(); } +#[cfg(feature = "std")] +fn compile_rpo_falcon() { + use std::path::PathBuf; + + let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect(); + let common_dir: PathBuf = ["PQClean", "common"].iter().collect(); + let rpo_dir: PathBuf = ["src", "dsa", "rpo_falcon512", "falcon_c"].iter().collect(); + + let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap(); + let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap(); + let rpo_files = glob::glob(rpo_dir.join("*.c").to_str().unwrap()).unwrap(); + + cc::Build::new() + .include(&common_dir) + .include(target_dir) + .files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) + .files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) + .files(rpo_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) + .compile("falcon-512_clean"); +} + #[cfg(feature = "arch-arm64-sve")] fn compile_arch_arm64_sve() { println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/library.c"); diff --git a/rustfmt.toml b/rustfmt.toml index 93e66a1..d73df8c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -16,5 +16,6 @@ newline_style = "Unix" #normalize_doc_attributes = true #reorder_impl_items = true single_line_if_else_max_width = 60 +struct_lit_width = 40 use_field_init_shorthand = true use_try_shorthand = true diff --git a/src/dsa/mod.rs b/src/dsa/mod.rs new file mode 100644 index 0000000..80fb8a8 --- /dev/null +++ b/src/dsa/mod.rs @@ -0,0 +1 @@ +pub mod rpo_falcon512; diff --git a/src/dsa/rpo_falcon512/error.rs b/src/dsa/rpo_falcon512/error.rs new file mode 100644 index 0000000..447a82c --- /dev/null +++ b/src/dsa/rpo_falcon512/error.rs @@ -0,0 +1,55 @@ +use super::{LOG_N, MODULUS, PK_LEN}; +use core::fmt; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FalconError { + KeyGenerationFailed, + PubKeyDecodingExtraData, + PubKeyDecodingInvalidCoefficient(u32), + PubKeyDecodingInvalidLength(usize), + PubKeyDecodingInvalidTag(u8), + SigDecodingTooBigHighBits(u32), + SigDecodingInvalidRemainder, + SigDecodingNonZeroUnusedBitsLastByte, + SigDecodingMinusZero, + SigDecodingIncorrectEncodingAlgorithm, + SigDecodingNotSupportedDegree(u8), + SigGenerationFailed, +} + +impl fmt::Display for FalconError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use FalconError::*; + match self { + KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"), + PubKeyDecodingExtraData => { + write!(f, "Failed to decode public key: input not fully consumed") + } + PubKeyDecodingInvalidCoefficient(val) => { + write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}") + } + PubKeyDecodingInvalidLength(len) => { + write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}") + } + PubKeyDecodingInvalidTag(byte) => { + write!(f, "Failed to 
decode public key: expected the first byte to be {LOG_N} but was {byte}") + } + SigDecodingTooBigHighBits(m) => { + write!(f, "Failed to decode signature: high bits {m} exceed 2048") + } + SigDecodingInvalidRemainder => { + write!(f, "Failed to decode signature: incorrect remaining data") + } + SigDecodingNonZeroUnusedBitsLastByte => { + write!(f, "Failed to decode signature: Non-zero unused bits in the last byte") + } + SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"), + SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"), + SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"), + SigGenerationFailed => write!(f, "Failed to generate a signature"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for FalconError {} diff --git a/src/dsa/rpo_falcon512/falcon_c/api_rpo.h b/src/dsa/rpo_falcon512/falcon_c/api_rpo.h new file mode 100644 index 0000000..bdcc3ec --- /dev/null +++ b/src/dsa/rpo_falcon512/falcon_c/api_rpo.h @@ -0,0 +1,66 @@ +#include +#include + +#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281 +#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897 +#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666 + +/* + * Generate a new key pair. Public key goes into pk[], private key in sk[]. + * Key sizes are exact (in bytes): + * public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES + * private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES + * + * Return value: 0 on success, -1 on error. + * + * Note: This implementation follows the reference implementation in PQClean + * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 + * verbatim except for the sections that are marked otherwise. + */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( + uint8_t *pk, uint8_t *sk); + +/* + * Generate a new key pair from seed. Public key goes into pk[], private key in sk[]. + * Key sizes are exact (in bytes): + * public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES + * private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES + * + * Return value: 0 on success, -1 on error. + */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( + uint8_t *pk, uint8_t *sk, unsigned char *seed); + +/* + * Compute a signature on a provided message (m, mlen), with a given + * private key (sk). Signature is written in sig[], with length written + * into *siglen. Signature length is variable; maximum signature length + * (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES. + * + * sig[], m[] and sk[] may overlap each other arbitrarily. + * + * Return value: 0 on success, -1 on error. + * + * Note: This implementation follows the reference implementation in PQClean + * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 + * verbatim except for the sections that are marked otherwise. + */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, const uint8_t *sk); + +/* + * Verify a signature (sig, siglen) on a message (m, mlen) with a given + * public key (pk). + * + * sig[], m[] and pk[] may overlap each other arbitrarily. + * + * Return value: 0 on success, -1 on error. 
+ * + * Note: This implementation follows the reference implementation in PQClean + * https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512 + * verbatim except for the sections that are marked otherwise. + */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, const uint8_t *pk); diff --git a/src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c b/src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c new file mode 100644 index 0000000..a1294d2 --- /dev/null +++ b/src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c @@ -0,0 +1,387 @@ +/* + * Wrapper for implementing the PQClean API. + */ + +#include +#include "randombytes.h" +#include "api_rpo.h" +#include "inner.h" +#include "rpo.h" + +#define NONCELEN 40 + +/* + * Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024) + * + * private key: + * header byte: 0101nnnn + * private f (6 or 5 bits by element, depending on degree) + * private g (6 or 5 bits by element, depending on degree) + * private F (8 bits by element) + * + * public key: + * header byte: 0000nnnn + * public h (14 bits by element) + * + * signature: + * header byte: 0011nnnn + * nonce 40 bytes + * value (12 bits by element) + * + * message + signature: + * signature length (2 bytes, big-endian) + * nonce 40 bytes + * message + * header byte: 0010nnnn + * value (12 bits by element) + * (signature length is 1+len(value), not counting the nonce) + */ + +/* see api_rpo.h */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( + uint8_t *pk, uint8_t *sk, unsigned char *seed) +{ + union + { + uint8_t b[FALCON_KEYGEN_TEMP_9]; + uint64_t dummy_u64; + fpr dummy_fpr; + } tmp; + int8_t f[512], g[512], F[512]; + uint16_t h[512]; + inner_shake256_context rng; + size_t u, v; + + /* + * Generate key pair. + */ + inner_shake256_init(&rng); + inner_shake256_inject(&rng, seed, sizeof seed); + inner_shake256_flip(&rng); + PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b); + inner_shake256_ctx_release(&rng); + + /* + * Encode private key. + */ + sk[0] = 0x50 + 9; + u = 1; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, + f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); + if (v == 0) + { + return -1; + } + u += v; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, + g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]); + if (v == 0) + { + return -1; + } + u += v; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode( + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u, + F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]); + if (v == 0) + { + return -1; + } + u += v; + if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) + { + return -1; + } + + /* + * Encode public key. + */ + pk[0] = 0x00 + 9; + v = PQCLEAN_FALCON512_CLEAN_modq_encode( + pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1, + h, 9); + if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) + { + return -1; + } + + return 0; +} + +int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( + uint8_t *pk, uint8_t *sk) +{ + unsigned char seed[48]; + + /* + * Generate a random seed. + */ + randombytes(seed, sizeof seed); + + return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed); +} + +/* + * Compute the signature. nonce[] receives the nonce and must have length + * NONCELEN bytes. 
sigbuf[] receives the signature value (without nonce + * or header byte), with *sigbuflen providing the maximum value length and + * receiving the actual value length. + * + * If a signature could be computed but not encoded because it would + * exceed the output buffer size, then a new signature is computed. If + * the provided buffer size is too low, this could loop indefinitely, so + * the caller must provide a size that can accommodate signatures with a + * large enough probability. + * + * Return value: 0 on success, -1 on error. + */ +static int +do_sign(uint8_t *nonce, uint8_t *sigbuf, size_t *sigbuflen, + const uint8_t *m, size_t mlen, const uint8_t *sk) +{ + union + { + uint8_t b[72 * 512]; + uint64_t dummy_u64; + fpr dummy_fpr; + } tmp; + int8_t f[512], g[512], F[512], G[512]; + struct + { + int16_t sig[512]; + uint16_t hm[512]; + } r; + unsigned char seed[48]; + inner_shake256_context sc; + rpo128_context rc; + size_t u, v; + + /* + * Decode the private key. + */ + if (sk[0] != 0x50 + 9) + { + return -1; + } + u = 1; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( + f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); + if (v == 0) + { + return -1; + } + u += v; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( + g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); + if (v == 0) + { + return -1; + } + u += v; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( + F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9], + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); + if (v == 0) + { + return -1; + } + u += v; + if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES) + { + return -1; + } + if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b)) + { + return -1; + } + + /* + * Create a random nonce (40 bytes). + */ + randombytes(nonce, NONCELEN); + + /* ==== Start: Deviation from the reference implementation ================================= */ + + // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that + // the conversion to field elements succeeds + uint8_t buffer[64]; + memset(buffer, 0, 64); + for (size_t i = 0; i < 8; i++) + { + buffer[8 * i] = nonce[5 * i]; + buffer[8 * i + 1] = nonce[5 * i + 1]; + buffer[8 * i + 2] = nonce[5 * i + 2]; + buffer[8 * i + 3] = nonce[5 * i + 3]; + buffer[8 * i + 4] = nonce[5 * i + 4]; + } + + /* + * Hash message nonce + message into a vector. + */ + rpo128_init(&rc); + rpo128_absorb(&rc, buffer, NONCELEN + 24); + rpo128_absorb(&rc, m, mlen); + rpo128_finalize(&rc); + PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9); + rpo128_release(&rc); + + /* ==== End: Deviation from the reference implementation =================================== */ + + /* + * Initialize a RNG. + */ + randombytes(seed, sizeof seed); + inner_shake256_init(&sc); + inner_shake256_inject(&sc, seed, sizeof seed); + inner_shake256_flip(&sc); + + /* + * Compute and return the signature. This loops until a signature + * value is found that fits in the provided buffer. + */ + for (;;) + { + PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b); + v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9); + if (v != 0) + { + inner_shake256_ctx_release(&sc); + *sigbuflen = v; + return 0; + } + } +} + +/* + * Verify a signature. The nonce has size NONCELEN bytes. sigbuf[] + * (of size sigbuflen) contains the signature value, not including the + * header byte or nonce. 
Return value is 0 on success, -1 on error. + */ +static int +do_verify( + const uint8_t *nonce, const uint8_t *sigbuf, size_t sigbuflen, + const uint8_t *m, size_t mlen, const uint8_t *pk) +{ + union + { + uint8_t b[2 * 512]; + uint64_t dummy_u64; + fpr dummy_fpr; + } tmp; + uint16_t h[512], hm[512]; + int16_t sig[512]; + rpo128_context rc; + + /* + * Decode public key. + */ + if (pk[0] != 0x00 + 9) + { + return -1; + } + if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9, + pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) + != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) + { + return -1; + } + PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9); + + /* + * Decode signature. + */ + if (sigbuflen == 0) + { + return -1; + } + if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen) + { + return -1; + } + + /* ==== Start: Deviation from the reference implementation ================================= */ + + /* + * Hash nonce + message into a vector. + */ + + // Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that + // the conversion to field elements succeeds + uint8_t buffer[64]; + memset(buffer, 0, 64); + for (size_t i = 0; i < 8; i++) + { + buffer[8 * i] = nonce[5 * i]; + buffer[8 * i + 1] = nonce[5 * i + 1]; + buffer[8 * i + 2] = nonce[5 * i + 2]; + buffer[8 * i + 3] = nonce[5 * i + 3]; + buffer[8 * i + 4] = nonce[5 * i + 4]; + } + + rpo128_init(&rc); + rpo128_absorb(&rc, buffer, NONCELEN + 24); + rpo128_absorb(&rc, m, mlen); + rpo128_finalize(&rc); + PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9); + rpo128_release(&rc); + + /* === End: Deviation from the reference implementation ==================================== */ + + /* + * Verify signature. + */ + if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b)) + { + return -1; + } + return 0; +} + +/* see api_rpo.h */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( + uint8_t *sig, size_t *siglen, + const uint8_t *m, size_t mlen, const uint8_t *sk) +{ + /* + * The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for + * the signed message object (as produced by crypto_sign()) + * and includes a two-byte length value, so we take care here + * to only generate signatures that are two bytes shorter than + * the maximum. This is done to ensure that crypto_sign() + * and crypto_sign_signature() produce the exact same signature + * value, if used on the same message, with the same private key, + * and using the same output from randombytes() (this is for + * reproducibility of tests). + */ + size_t vlen; + + vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3; + if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0) + { + return -1; + } + sig[0] = 0x30 + 9; + *siglen = 1 + NONCELEN + vlen; + return 0; +} + +/* see api_rpo.h */ +int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + const uint8_t *sig, size_t siglen, + const uint8_t *m, size_t mlen, const uint8_t *pk) +{ + if (siglen < 1 + NONCELEN) + { + return -1; + } + if (sig[0] != 0x30 + 9) + { + return -1; + } + return do_verify(sig + 1, + sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk); +} diff --git a/src/dsa/rpo_falcon512/falcon_c/rpo.c b/src/dsa/rpo_falcon512/falcon_c/rpo.c new file mode 100644 index 0000000..824f8a2 --- /dev/null +++ b/src/dsa/rpo_falcon512/falcon_c/rpo.c @@ -0,0 +1,582 @@ +/* + * RPO implementation. 
+ */ + +#include +#include +#include + +/* ================================================================================================ + * Modular Arithmetic + */ + +#define P 0xFFFFFFFF00000001 +#define M 12289 + +// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go +uint64_t add_mod_p(uint64_t a, uint64_t b) +{ + a = P - a; + uint64_t res = b - a; + if (b < a) + res += P; + return res; +} + +uint64_t sub_mod_p(uint64_t a, uint64_t b) +{ + uint64_t r = a - b; + if (a < b) + r += P; + return r; +} + +uint64_t reduce_mod_p(uint64_t b, uint64_t a) +{ + uint32_t d = b >> 32, + c = b; + if (a >= P) + a -= P; + a = sub_mod_p(a, c); + a = sub_mod_p(a, d); + a = add_mod_p(a, ((uint64_t)c) << 32); + return a; +} + +uint64_t mult_mod_p(uint64_t x, uint64_t y) +{ + uint32_t a = x, + b = x >> 32, + c = y, + d = y >> 32; + + /* first synthesize the product using 32*32 -> 64 bit multiplies */ + x = b * (uint64_t)c; /* b*c */ + y = a * (uint64_t)d; /* a*d */ + uint64_t e = a * (uint64_t)c, /* a*c */ + f = b * (uint64_t)d, /* b*d */ + t; + + x += y; /* b*c + a*d */ + /* carry? */ + if (x < y) + f += 1LL << 32; /* carry into upper 32 bits - can't overflow */ + + t = x << 32; + e += t; /* a*c + LSW(b*c + a*d) */ + /* carry? */ + if (e < t) + f += 1; /* carry into upper 64 bits - can't overflow*/ + t = x >> 32; + f += t; /* b*d + MSW(b*c + a*d) */ + /* can't overflow */ + + /* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */ + return reduce_mod_p(f, e); +} + +/* ================================================================================================ + * RPO128 Permutation + */ + +static const uint64_t STATE_WIDTH = 12; +static const uint64_t NUM_ROUNDS = 7; + +/* + * MDS matrix + */ +const uint64_t MDS[12][12] = { + { 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 }, + { 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 }, + { 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 }, + { 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6 }, + { 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7 }, + { 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9 }, + { 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10 }, + { 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13 }, + { 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26 }, + { 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8 }, + { 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23 }, + { 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7 }, +}; + +/* + * Round constants. 
+ */ +const uint64_t ARK1[7][12] = { + { + 5789762306288267392ULL, + 6522564764413701783ULL, + 17809893479458208203ULL, + 107145243989736508ULL, + 6388978042437517382ULL, + 15844067734406016715ULL, + 9975000513555218239ULL, + 3344984123768313364ULL, + 9959189626657347191ULL, + 12960773468763563665ULL, + 9602914297752488475ULL, + 16657542370200465908ULL, + }, + { + 12987190162843096997ULL, + 653957632802705281ULL, + 4441654670647621225ULL, + 4038207883745915761ULL, + 5613464648874830118ULL, + 13222989726778338773ULL, + 3037761201230264149ULL, + 16683759727265180203ULL, + 8337364536491240715ULL, + 3227397518293416448ULL, + 8110510111539674682ULL, + 2872078294163232137ULL, + }, + { + 18072785500942327487ULL, + 6200974112677013481ULL, + 17682092219085884187ULL, + 10599526828986756440ULL, + 975003873302957338ULL, + 8264241093196931281ULL, + 10065763900435475170ULL, + 2181131744534710197ULL, + 6317303992309418647ULL, + 1401440938888741532ULL, + 8884468225181997494ULL, + 13066900325715521532ULL, + }, + { + 5674685213610121970ULL, + 5759084860419474071ULL, + 13943282657648897737ULL, + 1352748651966375394ULL, + 17110913224029905221ULL, + 1003883795902368422ULL, + 4141870621881018291ULL, + 8121410972417424656ULL, + 14300518605864919529ULL, + 13712227150607670181ULL, + 17021852944633065291ULL, + 6252096473787587650ULL, + }, + { + 4887609836208846458ULL, + 3027115137917284492ULL, + 9595098600469470675ULL, + 10528569829048484079ULL, + 7864689113198939815ULL, + 17533723827845969040ULL, + 5781638039037710951ULL, + 17024078752430719006ULL, + 109659393484013511ULL, + 7158933660534805869ULL, + 2955076958026921730ULL, + 7433723648458773977ULL, + }, + { + 16308865189192447297ULL, + 11977192855656444890ULL, + 12532242556065780287ULL, + 14594890931430968898ULL, + 7291784239689209784ULL, + 5514718540551361949ULL, + 10025733853830934803ULL, + 7293794580341021693ULL, + 6728552937464861756ULL, + 6332385040983343262ULL, + 13277683694236792804ULL, + 2600778905124452676ULL, + }, + { + 7123075680859040534ULL, + 1034205548717903090ULL, + 7717824418247931797ULL, + 3019070937878604058ULL, + 11403792746066867460ULL, + 10280580802233112374ULL, + 337153209462421218ULL, + 13333398568519923717ULL, + 3596153696935337464ULL, + 8104208463525993784ULL, + 14345062289456085693ULL, + 17036731477169661256ULL, + }}; + +const uint64_t ARK2[7][12] = { + { + 6077062762357204287ULL, + 15277620170502011191ULL, + 5358738125714196705ULL, + 14233283787297595718ULL, + 13792579614346651365ULL, + 11614812331536767105ULL, + 14871063686742261166ULL, + 10148237148793043499ULL, + 4457428952329675767ULL, + 15590786458219172475ULL, + 10063319113072092615ULL, + 14200078843431360086ULL, + }, + { + 6202948458916099932ULL, + 17690140365333231091ULL, + 3595001575307484651ULL, + 373995945117666487ULL, + 1235734395091296013ULL, + 14172757457833931602ULL, + 707573103686350224ULL, + 15453217512188187135ULL, + 219777875004506018ULL, + 17876696346199469008ULL, + 17731621626449383378ULL, + 2897136237748376248ULL, + }, + { + 8023374565629191455ULL, + 15013690343205953430ULL, + 4485500052507912973ULL, + 12489737547229155153ULL, + 9500452585969030576ULL, + 2054001340201038870ULL, + 12420704059284934186ULL, + 355990932618543755ULL, + 9071225051243523860ULL, + 12766199826003448536ULL, + 9045979173463556963ULL, + 12934431667190679898ULL, + }, + { + 18389244934624494276ULL, + 16731736864863925227ULL, + 4440209734760478192ULL, + 17208448209698888938ULL, + 8739495587021565984ULL, + 17000774922218161967ULL, + 13533282547195532087ULL, + 525402848358706231ULL, + 
16987541523062161972ULL, + 5466806524462797102ULL, + 14512769585918244983ULL, + 10973956031244051118ULL, + }, + { + 6982293561042362913ULL, + 14065426295947720331ULL, + 16451845770444974180ULL, + 7139138592091306727ULL, + 9012006439959783127ULL, + 14619614108529063361ULL, + 1394813199588124371ULL, + 4635111139507788575ULL, + 16217473952264203365ULL, + 10782018226466330683ULL, + 6844229992533662050ULL, + 7446486531695178711ULL, + }, + { + 3736792340494631448ULL, + 577852220195055341ULL, + 6689998335515779805ULL, + 13886063479078013492ULL, + 14358505101923202168ULL, + 7744142531772274164ULL, + 16135070735728404443ULL, + 12290902521256031137ULL, + 12059913662657709804ULL, + 16456018495793751911ULL, + 4571485474751953524ULL, + 17200392109565783176ULL, + }, + { + 17130398059294018733ULL, + 519782857322261988ULL, + 9625384390925085478ULL, + 1664893052631119222ULL, + 7629576092524553570ULL, + 3485239601103661425ULL, + 9755891797164033838ULL, + 15218148195153269027ULL, + 16460604813734957368ULL, + 9643968136937729763ULL, + 3611348709641382851ULL, + 18256379591337759196ULL, + }, +}; + +void apply_sbox(uint64_t *const state) +{ + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + uint64_t t2 = mult_mod_p(*(state + i), *(state + i)); + uint64_t t4 = mult_mod_p(t2, t2); + + *(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4)); + } +} + +void apply_mds(uint64_t *state) +{ + uint64_t res[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + res[i] = 0; + } + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + for (uint64_t j = 0; j < STATE_WIDTH; j++) + { + res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j))); + } + } + + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + *(state + i) = res[i]; + } +} + +void apply_constants(uint64_t *const state, const uint64_t *ark) +{ + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + *(state + i) = add_mod_p(*(state + i), *(ark + i)); + } +} + +void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res) +{ + for (uint64_t i = 0; i < m; i++) + { + for (uint64_t j = 0; j < STATE_WIDTH; j++) + { + if (i == 0) + { + *(res + j) = mult_mod_p(*(base + j), *(base + j)); + } + else + { + *(res + j) = mult_mod_p(*(res + j), *(res + j)); + } + } + } + + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + *(res + i) = mult_mod_p(*(res + i), *(tail + i)); + } +} + +void apply_inv_sbox(uint64_t *const state) +{ + uint64_t t1[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t1[i] = 0; + } + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t1[i] = mult_mod_p(*(state + i), *(state + i)); + } + + uint64_t t2[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t2[i] = 0; + } + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t2[i] = mult_mod_p(t1[i], t1[i]); + } + + uint64_t t3[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t3[i] = 0; + } + exp_acc(3, t2, t2, t3); + + uint64_t t4[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t4[i] = 0; + } + exp_acc(6, t3, t3, t4); + + uint64_t tmp[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + tmp[i] = 0; + } + exp_acc(12, t4, t4, tmp); + + uint64_t t5[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t5[i] = 0; + } + exp_acc(6, tmp, t3, t5); + + uint64_t t6[STATE_WIDTH]; + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + t6[i] = 0; + } + exp_acc(31, t5, t5, t6); + + for (uint64_t i = 0; i < STATE_WIDTH; i++) + { + uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]); + a = mult_mod_p(a, a); + a 
= mult_mod_p(a, a); + uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i)); + + *(state + i) = mult_mod_p(a, b); + } +} + +void apply_round(uint64_t *const state, const uint64_t round) +{ + apply_mds(state); + apply_constants(state, ARK1[round]); + apply_sbox(state); + + apply_mds(state); + apply_constants(state, ARK2[round]); + apply_inv_sbox(state); +} + +static void apply_permutation(uint64_t *state) +{ + for (uint64_t i = 0; i < NUM_ROUNDS; i++) + { + apply_round(state, i); + } +} + +/* ================================================================================================ + * RPO128 implementation. This is supposed to substitute SHAKE256 in the hash-to-point algorithm. + */ + +#include "rpo.h" + +void rpo128_init(rpo128_context *rc) +{ + rc->dptr = 32; + + memset(rc->st.A, 0, sizeof rc->st.A); +} + +void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len) +{ + size_t dptr; + + dptr = (size_t)rc->dptr; + while (len > 0) + { + size_t clen, u; + + /* 136 * 8 = 1088 bit for the rate portion in the case of SHAKE256 + * For RPO, this is 64 * 8 = 512 bits + * The capacity for SHAKE256 is at the end while for RPO128 it is at the beginning + */ + clen = 96 - dptr; + if (clen > len) + { + clen = len; + } + + for (u = 0; u < clen; u++) + { + rc->st.dbuf[dptr + u] = in[u]; + } + + dptr += clen; + in += clen; + len -= clen; + if (dptr == 96) + { + apply_permutation(rc->st.A); + dptr = 32; + } + } + rc->dptr = dptr; +} + +void rpo128_finalize(rpo128_context *rc) +{ + // Set dptr to the end of the buffer, so that first call to extract will call the permutation. + rc->dptr = 96; +} + +void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len) +{ + size_t dptr; + + dptr = (size_t)rc->dptr; + while (len > 0) + { + size_t clen; + + if (dptr == 96) + { + apply_permutation(rc->st.A); + dptr = 32; + } + clen = 96 - dptr; + if (clen > len) + { + clen = len; + } + len -= clen; + + memcpy(out, rc->st.dbuf + dptr, clen); + dptr += clen; + out += clen; + } + rc->dptr = dptr; +} + +void rpo128_release(rpo128_context *rc) +{ + memset(rc->st.A, 0, sizeof rc->st.A); + rc->dptr = 32; +} + +/* ================================================================================================ + * Hash-to-Point algorithm implementation based on RPO128 + */ + +void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn) +{ + /* + * This implementation avoids the rejection sampling step needed in the + * per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf + * page 31, which argues that the current variant is secure for the parameters set by NIST. + * Avoiding the rejection-sampling step leads to an implementation that is constant-time. + * TODO: Check that the current implementation is indeed constant-time. 
+     */
+    size_t n;
+
+    n = (size_t)1 << logn;
+    while (n > 0)
+    {
+        uint8_t buf[8];
+        uint64_t w;
+
+        rpo128_squeeze(rc, (void *)buf, sizeof buf);
+        w = ((uint64_t)(buf[7]) << 56) |
+            ((uint64_t)(buf[6]) << 48) |
+            ((uint64_t)(buf[5]) << 40) |
+            ((uint64_t)(buf[4]) << 32) |
+            ((uint64_t)(buf[3]) << 24) |
+            ((uint64_t)(buf[2]) << 16) |
+            ((uint64_t)(buf[1]) << 8) |
+            ((uint64_t)(buf[0]));
+
+        w %= M;
+
+        *x++ = (uint16_t)w;
+        n--;
+    }
+}
\ No newline at end of file
diff --git a/src/dsa/rpo_falcon512/falcon_c/rpo.h b/src/dsa/rpo_falcon512/falcon_c/rpo.h
new file mode 100644
index 0000000..d9038af
--- /dev/null
+++ b/src/dsa/rpo_falcon512/falcon_c/rpo.h
@@ -0,0 +1,83 @@
+#include <stdint.h>
+#include <stddef.h>
+
+/* ================================================================================================
+ * RPO hashing algorithm related structs and methods.
+ */
+
+/*
+ * RPO128 context.
+ *
+ * This structure is used by the hashing API. It is composed of an internal state that can be
+ * viewed as either:
+ * 1. 12 field elements in the Miden VM.
+ * 2. 96 bytes.
+ *
+ * The first view is used for the internal state in the context of the RPO hashing algorithm. The
+ * second view is used for the buffer used to absorb the data to be hashed.
+ *
+ * The pointer to the buffer is updated as the data is absorbed.
+ *
+ * 'rpo128_context' must be initialized with rpo128_init() before first use.
+ */
+typedef struct
+{
+    union
+    {
+        uint64_t A[12];
+        uint8_t dbuf[96];
+    } st;
+    uint64_t dptr;
+} rpo128_context;
+
+/*
+ * Initializes an RPO state.
+ */
+void rpo128_init(rpo128_context *rc);
+
+/*
+ * Absorbs an array of bytes of length 'len' into the state.
+ */
+void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len);
+
+/*
+ * Squeezes an array of bytes of length 'len' from the state.
+ */
+void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len);
+
+/*
+ * Finalizes the state in preparation for squeezing.
+ *
+ * This function should be called after all the data has been absorbed.
+ *
+ * Note that the current implementation does not perform any sort of padding for domain separation
+ * purposes. The reason being that, for our purposes, we always perform the following sequence:
+ * 1. Absorb a nonce (which is always 40 bytes packed as 8 field elements).
+ * 2. Absorb the message (which is always 4 field elements).
+ * 3. Call finalize.
+ * 4. Squeeze the output.
+ * 5. Call release.
+ */
+void rpo128_finalize(rpo128_context *rc);
+
+/*
+ * Releases the state.
+ *
+ * This function should be called after the squeeze operation is finished.
+ */
+void rpo128_release(rpo128_context *rc);
+
+/* ================================================================================================
+ * Hash-to-Point algorithm for signature generation and signature verification.
+ */
+
+/*
+ * Hash-to-Point algorithm.
+ *
+ * This function generates a point in Z_q[x]/(phi) from a given message.
+ *
+ * It takes a finalized rpo128_context as input and generates the coefficients of the polynomial
+ * representing the point. The coefficients are stored in the array 'x'. The number of coefficients
+ * is 2^logn, which in our case must be 512 (i.e., logn = 9).
+ */ +void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn); diff --git a/src/dsa/rpo_falcon512/ffi.rs b/src/dsa/rpo_falcon512/ffi.rs new file mode 100644 index 0000000..f21b3ed --- /dev/null +++ b/src/dsa/rpo_falcon512/ffi.rs @@ -0,0 +1,189 @@ +use libc::c_int; + +// C IMPLEMENTATION INTERFACE +// ================================================================================================ + +extern "C" { + /// Generate a new key pair. Public key goes into pk[], private key in sk[]. + /// Key sizes are exact (in bytes): + /// - public (pk): 897 + /// - private (sk): 1281 + /// + /// Return value: 0 on success, -1 on error. + pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int; + + /// Generate a new key pair from seed. Public key goes into pk[], private key in sk[]. + /// Key sizes are exact (in bytes): + /// - public (pk): 897 + /// - private (sk): 1281 + /// + /// Return value: 0 on success, -1 on error. + pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( + pk: *mut u8, + sk: *mut u8, + seed: *const u8, + ) -> c_int; + + /// Compute a signature on a provided message (m, mlen), with a given private key (sk). + /// Signature is written in sig[], with length written into *siglen. Signature length is + /// variable; maximum signature length (in bytes) is 666. + /// + /// sig[], m[] and sk[] may overlap each other arbitrarily. + /// + /// Return value: 0 on success, -1 on error. + pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( + sig: *mut u8, + siglen: *mut usize, + m: *const u8, + mlen: usize, + sk: *const u8, + ) -> c_int; + + // TEST HELPERS + // -------------------------------------------------------------------------------------------- + + /// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk). + /// + /// sig[], m[] and pk[] may overlap each other arbitrarily. + /// + /// Return value: 0 on success, -1 on error. + #[cfg(test)] + pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + sig: *const u8, + siglen: usize, + m: *const u8, + mlen: usize, + pk: *const u8, + ) -> c_int; + + /// Hash-to-Point algorithm. + /// + /// This function generates a point in Z_q[x]/(phi) from a given message. + /// + /// It takes a finalized rpo128_context as input and it generates the coefficients of the polynomial + /// representing the point. The coefficients are stored in the array 'x'. The number of coefficients + /// is given by 'logn', which must in our case is 512. 
+ #[cfg(test)] + pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo( + rc: *mut Rpo128Context, + x: *mut u16, + logn: usize, + ); + + #[cfg(test)] + pub fn rpo128_init(sc: *mut Rpo128Context); + + #[cfg(test)] + pub fn rpo128_absorb( + sc: *mut Rpo128Context, + data: *const ::std::os::raw::c_void, + len: libc::size_t, + ); + + #[cfg(test)] + pub fn rpo128_finalize(sc: *mut Rpo128Context); +} + +#[repr(C)] +#[cfg(test)] +pub struct Rpo128Context { + pub content: [u64; 13usize], +} + +// TESTS +// ================================================================================================ + +#[cfg(all(test, feature = "std"))] +mod tests { + use super::*; + use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN}; + use rand::Rng; + + #[test] + fn falcon_ffi() { + unsafe { + let mut rng = rand::thread_rng(); + + // --- generate a key pair from a seed ---------------------------- + + let mut pk = [0u8; PK_LEN]; + let mut sk = [0u8; SK_LEN]; + let seed: [u8; NONCE_LEN] = + (0..NONCE_LEN).map(|_| rng.gen()).collect::>().try_into().unwrap(); + + assert_eq!( + 0, + PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( + pk.as_mut_ptr(), + sk.as_mut_ptr(), + seed.as_ptr() + ) + ); + + // --- sign a message and make sure it verifies ------------------- + + let mlen: usize = rng.gen::() as usize; + let msg: Vec = (0..mlen).map(|_| rng.gen()).collect(); + let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN]; + let mut siglen = 0; + + assert_eq!( + 0, + PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( + detached_sig.as_mut_ptr(), + &mut siglen as *mut usize, + msg.as_ptr(), + msg.len(), + sk.as_ptr() + ) + ); + + assert_eq!( + 0, + PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + detached_sig.as_ptr(), + siglen, + msg.as_ptr(), + msg.len(), + pk.as_ptr() + ) + ); + + // --- check verification of different signature ------------------ + + assert_eq!( + -1, + PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + detached_sig.as_ptr(), + siglen, + msg.as_ptr(), + msg.len() - 1, + pk.as_ptr() + ) + ); + + // --- check verification against a different pub key ------------- + + let mut pk_alt = [0u8; PK_LEN]; + let mut sk_alt = [0u8; SK_LEN]; + assert_eq!( + 0, + PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( + pk_alt.as_mut_ptr(), + sk_alt.as_mut_ptr() + ) + ); + + assert_eq!( + -1, + PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( + detached_sig.as_ptr(), + siglen, + msg.as_ptr(), + msg.len(), + pk_alt.as_ptr() + ) + ); + } + } +} diff --git a/src/dsa/rpo_falcon512/keys.rs b/src/dsa/rpo_falcon512/keys.rs new file mode 100644 index 0000000..9170245 --- /dev/null +++ b/src/dsa/rpo_falcon512/keys.rs @@ -0,0 +1,227 @@ +use super::{ + ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial, + PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word, +}; + +#[cfg(feature = "std")] +use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN}; + +// PUBLIC KEY +// ================================================================================================ + +/// A public key for verifying signatures. +/// +/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of +/// the polynomial representing the raw bytes of the expanded public key. +/// +/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512) i.e., 9. 
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct PublicKey(Word);
+
+impl PublicKey {
+    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
+    ///
+    /// # Errors
+    /// Returns an error if the decoding of the public key fails.
+    pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
+        let h = Polynomial::from_pub_key(&pk)?;
+        let pk_felts = h.to_elements();
+        let pk_digest = Rpo256::hash_elements(&pk_felts).into();
+        Ok(Self(pk_digest))
+    }
+
+    /// Verifies the provided signature against the provided message and this public key.
+    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
+        signature.verify(message, self.0)
+    }
+}
+
+impl From<PublicKey> for Word {
+    fn from(key: PublicKey) -> Self {
+        key.0
+    }
+}
+
+// KEY PAIR
+// ================================================================================================
+
+/// A key pair (public and secret keys) for signing messages.
+///
+/// The secret key is a byte array of length [SK_LEN].
+/// The public key is a byte array of length [PK_LEN].
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct KeyPair {
+    public_key: PublicKeyBytes,
+    secret_key: SecretKeyBytes,
+}
+
+#[allow(clippy::new_without_default)]
+impl KeyPair {
+    // CONSTRUCTORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Generates a (public_key, secret_key) key pair from OS-provided randomness.
+    ///
+    /// # Errors
+    /// Returns an error if key generation fails.
+    #[cfg(feature = "std")]
+    pub fn new() -> Result<Self, FalconError> {
+        let mut public_key = [0u8; PK_LEN];
+        let mut secret_key = [0u8; SK_LEN];
+
+        let res = unsafe {
+            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
+                public_key.as_mut_ptr(),
+                secret_key.as_mut_ptr(),
+            )
+        };
+
+        if res == 0 {
+            Ok(Self { public_key, secret_key })
+        } else {
+            Err(FalconError::KeyGenerationFailed)
+        }
+    }
+
+    /// Generates a (public_key, secret_key) key pair from the provided seed.
+    ///
+    /// # Errors
+    /// Returns an error if key generation fails.
+    #[cfg(feature = "std")]
+    pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
+        let mut public_key = [0u8; PK_LEN];
+        let mut secret_key = [0u8; SK_LEN];
+
+        let res = unsafe {
+            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
+                public_key.as_mut_ptr(),
+                secret_key.as_mut_ptr(),
+                seed.as_ptr(),
+            )
+        };
+
+        if res == 0 {
+            Ok(Self { public_key, secret_key })
+        } else {
+            Err(FalconError::KeyGenerationFailed)
+        }
+    }
+
+    // PUBLIC ACCESSORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Returns the public key corresponding to this key pair.
+    pub fn public_key(&self) -> PublicKey {
+        // TODO: memoize public key commitment as computing it requires quite a bit of hashing.
+        // expect() is fine here because we assume that the key pair was constructed correctly.
+        PublicKey::new(self.public_key).expect("invalid key pair")
+    }
+
+    /// Returns the expanded public key corresponding to this key pair.
+    pub fn expanded_public_key(&self) -> PublicKeyBytes {
+        self.public_key
+    }
+
+    // SIGNATURE GENERATION
+    // --------------------------------------------------------------------------------------------
+
+    /// Signs a message with the secret key of this key pair.
+    ///
+    /// # Errors
+    /// Returns an error if signature generation fails.
+    #[cfg(feature = "std")]
+    pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
+        let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<u8>>();
+        let msg_len = msg.len();
+        let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
+        let mut sig_len: usize = 0;
+
+        let res = unsafe {
+            ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
+                sig.as_mut_ptr(),
+                &mut sig_len as *mut usize,
+                msg.as_ptr(),
+                msg_len,
+                self.secret_key.as_ptr(),
+            )
+        };
+
+        if res == 0 {
+            Ok(Signature { sig, pk: self.public_key })
+        } else {
+            Err(FalconError::SigGenerationFailed)
+        }
+    }
+}
+
+// SERIALIZATION / DESERIALIZATION
+// ================================================================================================
+
+impl Serializable for KeyPair {
+    fn write_into<W: ByteWriter>(&self, target: &mut W) {
+        target.write_bytes(&self.public_key);
+        target.write_bytes(&self.secret_key);
+    }
+}
+
+impl Deserializable for KeyPair {
+    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
+        let public_key: PublicKeyBytes = source.read_array()?;
+        let secret_key: SecretKeyBytes = source.read_array()?;
+        Ok(Self { public_key, secret_key })
+    }
+}
+
+// TESTS
+// ================================================================================================
+
+#[cfg(all(test, feature = "std"))]
+mod tests {
+    use super::{super::Felt, KeyPair, NonceBytes, Word};
+    use rand_utils::{rand_array, rand_vector};
+
+    #[test]
+    fn test_falcon_verification() {
+        // generate random keys
+        let keys = KeyPair::new().unwrap();
+        let pk = keys.public_key();
+
+        // sign a random message
+        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
+        let signature = keys.sign(message);
+
+        // make sure the signature verifies correctly
+        assert!(pk.verify(message, signature.as_ref().unwrap()));
+
+        // a signature should not verify against a wrong message
+        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
+        assert!(!pk.verify(message2, signature.as_ref().unwrap()));
+
+        // a signature should not verify against a wrong public key
+        let keys2 = KeyPair::new().unwrap();
+        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
+    }
+
+    #[test]
+    fn test_falcon_verification_from_seed() {
+        // generate keys from a random seed
+        let seed: NonceBytes = rand_array();
+        let keys = KeyPair::from_seed(&seed).unwrap();
+        let pk = keys.public_key();
+
+        // sign a random message
+        let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
+        let signature = keys.sign(message);
+
+        // make sure the signature verifies correctly
+        assert!(pk.verify(message, signature.as_ref().unwrap()));
+
+        // a signature should not verify against a wrong message
+        let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
+        assert!(!pk.verify(message2, signature.as_ref().unwrap()));
+
+        // a signature should not verify against a wrong public key
+        let keys2 = KeyPair::new().unwrap();
+        assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
+    }
+}
diff --git a/src/dsa/rpo_falcon512/mod.rs b/src/dsa/rpo_falcon512/mod.rs
new file mode 100644
index 0000000..5bbe5cf
--- /dev/null
+++ b/src/dsa/rpo_falcon512/mod.rs
@@ -0,0 +1,60 @@
+use crate::{
+    hash::rpo::Rpo256,
+    utils::{
+        collections::Vec, ByteReader, ByteWriter, Deserializable, DeserializationError,
+        Serializable,
+    },
+    Felt, StarkField, Word, ZERO,
+};
+
+#[cfg(feature = "std")]
+mod ffi;
+
+mod error;
+mod keys;
+mod polynomial;
+mod signature;
+
+pub use error::FalconError;
+pub use keys::{KeyPair, PublicKey};
+pub use polynomial::Polynomial;
+pub use signature::Signature;
+
+// CONSTANTS
+// ================================================================================================
+
+// The Falcon modulus.
+const MODULUS: u16 = 12289;
+const MODULUS_MINUS_1_OVER_TWO: u16 = 6144;
+
+// The Falcon parameters for Falcon-512: N is the degree of the polynomial `phi := x^N + 1`
+// defining the ring Z_p[x]/(phi), and LOG_N = log2(N).
+const N: usize = 512;
+const LOG_N: usize = 9;
+
+/// Length of the nonce (in bytes) used during signature generation; also used as the length of
+/// the seed for seeded key-pair generation.
+const NONCE_LEN: usize = 40;
+
+/// Number of field elements used to encode a nonce.
+const NONCE_ELEMENTS: usize = 8;
+
+/// Public key length as a u8 vector.
+const PK_LEN: usize = 897;
+
+/// Secret key length as a u8 vector.
+const SK_LEN: usize = 1281;
+
+/// Signature length as a u8 vector.
+const SIG_LEN: usize = 626;
+
+/// Bound on the squared-norm of the signature.
+const SIG_L2_BOUND: u64 = 34034726;
+
+// TYPE ALIASES
+// ================================================================================================
+
+type SignatureBytes = [u8; NONCE_LEN + SIG_LEN];
+type PublicKeyBytes = [u8; PK_LEN];
+type SecretKeyBytes = [u8; SK_LEN];
+type NonceBytes = [u8; NONCE_LEN];
+type NonceElements = [Felt; NONCE_ELEMENTS];
diff --git a/src/dsa/rpo_falcon512/polynomial.rs b/src/dsa/rpo_falcon512/polynomial.rs
new file mode 100644
index 0000000..fdb9e34
--- /dev/null
+++ b/src/dsa/rpo_falcon512/polynomial.rs
@@ -0,0 +1,277 @@
+use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
+use core::ops::{Add, Mul, Sub};
+
+// FALCON POLYNOMIAL
+// ================================================================================================
+
+/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1.
+#[derive(Debug, Copy, Clone, PartialEq)]
+pub struct Polynomial([u16; N]);
+
+impl Polynomial {
+    // CONSTRUCTORS
+    // --------------------------------------------------------------------------------------------
+
+    /// Constructs a new polynomial from a list of coefficients.
+    ///
+    /// # Safety
+    /// This constructor validates that the coefficients are in the valid range only in debug mode.
+    pub unsafe fn new(data: [u16; N]) -> Self {
+        for value in data {
+            debug_assert!(value < MODULUS);
+        }
+
+        Self(data)
+    }
+
+    /// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
+    ///
+    /// # Errors
+    /// Returns an error if:
+    /// - The provided input is not exactly 897 bytes long.
+    /// - The first byte of the input is not equal to log2(512), i.e., 9.
+    /// - Any of the coefficients encoded in the provided input is greater than or equal to the
+    ///   Falcon field modulus.
+ pub fn from_pub_key(input: &[u8]) -> Result { + if input.len() != PK_LEN { + return Err(FalconError::PubKeyDecodingInvalidLength(input.len())); + } + + if input[0] != LOG_N as u8 { + return Err(FalconError::PubKeyDecodingInvalidTag(input[0])); + } + + let mut acc = 0_u32; + let mut acc_len = 0; + + let mut output = [0_u16; N]; + let mut output_idx = 0; + + for &byte in input.iter().skip(1) { + acc = (acc << 8) | (byte as u32); + acc_len += 8; + + if acc_len >= 14 { + acc_len -= 14; + let w = (acc >> acc_len) & 0x3FFF; + if w >= MODULUS as u32 { + return Err(FalconError::PubKeyDecodingInvalidCoefficient(w)); + } + output[output_idx] = w as u16; + output_idx += 1; + } + } + + if (acc & ((1u32 << acc_len) - 1)) == 0 { + Ok(Self(output)) + } else { + Err(FalconError::PubKeyDecodingExtraData) + } + } + + /// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes + /// that the signature has been encoded using the uncompressed format. + /// + /// # Errors + /// Returns an error if: + /// - The signature has been encoded using a different algorithm than the reference compressed + /// encoding algorithm. + /// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512. + /// - While decoding the high bits of a coefficient, the current accumulated value of its + /// high bits is larger than 2048. + /// - The decoded coefficient is -0. + /// - The remaining unused bits in the last byte of `input` are non-zero. + pub fn from_signature(input: &[u8]) -> Result { + let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111); + + if encoding != 0b0011 { + return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm); + } + if log_n != 0b1001 { + return Err(FalconError::SigDecodingNotSupportedDegree(log_n)); + } + + let input = &input[41..]; + let mut input_idx = 0; + let mut acc = 0u32; + let mut acc_len = 0; + let mut output = [0_u16; N]; + + for e in output.iter_mut() { + acc = (acc << 8) | (input[input_idx] as u32); + input_idx += 1; + let b = acc >> acc_len; + let s = b & 128; + let mut m = b & 127; + + loop { + if acc_len == 0 { + acc = (acc << 8) | (input[input_idx] as u32); + input_idx += 1; + acc_len = 8; + } + acc_len -= 1; + if ((acc >> acc_len) & 1) != 0 { + break; + } + m += 128; + if m >= 2048 { + return Err(FalconError::SigDecodingTooBigHighBits(m)); + } + } + if s != 0 && m == 0 { + return Err(FalconError::SigDecodingMinusZero); + } + + *e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 }; + } + + if (acc & ((1 << acc_len) - 1)) != 0 { + return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte); + } + + Ok(Self(output)) + } + + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the coefficients of this polynomial as integers. + pub fn inner(&self) -> [u16; N] { + self.0 + } + + /// Returns the coefficients of this polynomial as field elements. + pub fn to_elements(&self) -> Vec { + self.0.iter().map(|&a| Felt::from(a)).collect() + } + + // POLYNOMIAL OPERATIONS + // -------------------------------------------------------------------------------------------- + + /// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees + /// of the input polynomials are less than 512 and their coefficients are less than the modulus + /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less + /// than the Miden prime. 
+ /// + /// Note that this multiplication is not over Z_p[x]/(phi). + pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] { + let mut c = [0; 2 * N]; + for i in 0..N { + for j in 0..N { + c[i + j] += a.0[i] as u64 * b.0[j] as u64; + } + } + + c + } + + /// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo + /// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi). + pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self { + let mut c = [0; N]; + for i in 0..N { + let ai = a[N + i] % MODULUS as u64; + let neg_ai = (MODULUS - ai as u16) % MODULUS; + + let bi = (a[i] % MODULUS as u64) as u16; + c[i] = (neg_ai + bi) % MODULUS; + } + + Self(c) + } + + /// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its + /// coefficients to be in the interval (-p/2, p/2]. + pub fn sq_norm(&self) -> u64 { + let mut res = 0; + for e in self.0 { + if e > MODULUS_MINUS_1_OVER_TWO { + res += (MODULUS - e) as u64 * (MODULUS - e) as u64 + } else { + res += e as u64 * e as u64 + } + } + res + } +} + +// Returns a polynomial representing the zero polynomial i.e. default element. +impl Default for Polynomial { + fn default() -> Self { + Self([0_u16; N]) + } +} + +/// Multiplication over Z_p[x]/(phi) +impl Mul for Polynomial { + type Output = Self; + + fn mul(self, other: Self) -> >::Output { + let mut result = [0_u16; N]; + for j in 0..N { + for k in 0..N { + let i = (j + k) % N; + let a = self.0[j] as usize; + let b = other.0[k] as usize; + let q = MODULUS as usize; + let mut prod = a * b % q; + if (N - 1) < (j + k) { + prod = (q - prod) % q; + } + result[i] = ((result[i] as usize + prod) % q) as u16; + } + } + + Polynomial(result) + } +} + +/// Addition over Z_p[x]/(phi) +impl Add for Polynomial { + type Output = Self; + + fn add(self, other: Self) -> >::Output { + let mut res = self; + res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS); + + res + } +} + +/// Subtraction over Z_p[x]/(phi) +impl Sub for Polynomial { + type Output = Self; + + fn sub(self, other: Self) -> >::Output { + let mut res = self; + res.0 + .iter_mut() + .zip(other.0.iter()) + .for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS); + + res + } +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + use super::{Polynomial, N}; + + #[test] + fn test_negacyclic_reduction() { + let coef1: [u16; N] = rand_utils::rand_array(); + let coef2: [u16; N] = rand_utils::rand_array(); + + let poly1 = Polynomial(coef1); + let poly2 = Polynomial(coef2); + + assert_eq!( + poly1 * poly2, + Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2)) + ); + } +} diff --git a/src/dsa/rpo_falcon512/signature.rs b/src/dsa/rpo_falcon512/signature.rs new file mode 100644 index 0000000..afcde98 --- /dev/null +++ b/src/dsa/rpo_falcon512/signature.rs @@ -0,0 +1,262 @@ +use super::{ + ByteReader, ByteWriter, Deserializable, DeserializationError, NonceBytes, NonceElements, + Polynomial, PublicKeyBytes, Rpo256, Serializable, SignatureBytes, StarkField, Word, MODULUS, N, + SIG_L2_BOUND, ZERO, +}; +use crate::utils::string::ToString; + +// FALCON SIGNATURE +// ================================================================================================ + +/// An RPO Falcon512 signature over a message. 
+/// +/// The signature is a pair of polynomials (s1, s2) in (Z_p[x]/(phi))^2, where: +/// - p := 12289 +/// - phi := x^512 + 1 +/// - s1 = c - s2 * h +/// - h is a polynomial representing the public key and c is a polynomial that is the hash-to-point +/// of the message being signed. +/// +/// The signature verifies if and only if: +/// 1. s1 = c - s2 * h +/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND +/// +/// where |.| is the norm. +/// +/// [Signature] also includes the extended public key which is serialized as: +/// 1. 1 byte representing the log2(512) i.e., 9. +/// 2. 896 bytes for the public key. This is decoded into the `h` polynomial above. +/// +/// The actual signature is serialized as: +/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial +/// together with the degree of the irreducible polynomial phi. +/// The general format of this byte is 0b0cc1nnnn where: +/// a. cc is either 01 when the compressed encoding algorithm is used and 10 when the +/// uncompressed algorithm is used. +/// b. nnnn is log2(N) where N is the degree of the irreducible polynomial phi. +/// The current implementation works always with cc equal to 0b01 and nnnn equal to 0b1001 and +/// thus the header byte is always equal to 0b00111001. +/// 2. 40 bytes for the nonce. +/// 3. 625 bytes encoding the `s2` polynomial above. +/// +/// The total size of the signature (including the extended public key) is 1563 bytes. +pub struct Signature { + pub(super) pk: PublicKeyBytes, + pub(super) sig: SignatureBytes, +} + +impl Signature { + // PUBLIC ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the public key polynomial h. + pub fn pub_key_poly(&self) -> Polynomial { + // TODO: memoize + // we assume that the signature was constructed with a valid public key, and thus + // expect() is OK here. + Polynomial::from_pub_key(&self.pk).expect("invalid public key") + } + + /// Returns the nonce component of the signature represented as field elements. + /// + /// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks + /// of the nonce and interpreting them as field elements. + pub fn nonce(&self) -> NonceElements { + // we assume that the signature was constructed with a valid signature, and thus + // expect() is OK here. + let nonce = self.sig[1..41].try_into().expect("invalid signature"); + decode_nonce(nonce) + } + + // Returns the polynomial representation of the signature in Z_p[x]/(phi). + pub fn sig_poly(&self) -> Polynomial { + // TODO: memoize + // we assume that the signature was constructed with a valid signature, and thus + // expect() is OK here. + Polynomial::from_signature(&self.sig).expect("invalid signature") + } + + // HASH-TO-POINT + // -------------------------------------------------------------------------------------------- + + /// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message. + pub fn hash_to_point(&self, message: Word) -> Polynomial { + hash_to_point(message, &self.nonce()) + } + + // SIGNATURE VERIFICATION + // -------------------------------------------------------------------------------------------- + /// Returns true if this signature is a valid signature for the specified message generated + /// against key pair matching the specified public key commitment. 
+ pub fn verify(&self, message: Word, pubkey_com: Word) -> bool { + // Make sure the expanded public key matches the provided public key commitment + let h = self.pub_key_poly(); + let h_digest: Word = Rpo256::hash_elements(&h.to_elements()).into(); + if h_digest != pubkey_com { + return false; + } + + // Make sure the signature is valid + let s2 = self.sig_poly(); + let c = self.hash_to_point(message); + + let s1 = c - s2 * h; + + let sq_norm = s1.sq_norm() + s2.sq_norm(); + sq_norm <= SIG_L2_BOUND + } +} + +// SERIALIZATION / DESERIALIZATION +// ================================================================================================ + +impl Serializable for Signature { + fn write_into(&self, target: &mut W) { + target.write_bytes(&self.pk); + target.write_bytes(&self.sig); + } +} + +impl Deserializable for Signature { + fn read_from(source: &mut R) -> Result { + let pk: PublicKeyBytes = source.read_array()?; + let sig: SignatureBytes = source.read_array()?; + + // make sure public key and signature can be decoded correctly + Polynomial::from_pub_key(&pk) + .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; + Polynomial::from_signature(&sig[41..]) + .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; + + Ok(Self { pk, sig }) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and +/// nonce. +fn hash_to_point(message: Word, nonce: &NonceElements) -> Polynomial { + let mut state = [ZERO; Rpo256::STATE_WIDTH]; + + // absorb the nonce into the state + for (&n, s) in nonce.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) { + *s = n; + } + Rpo256::apply_permutation(&mut state); + + // absorb message into the state + for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) { + *s = m; + } + + // squeeze the coefficients of the polynomial + let mut i = 0; + let mut res = [0_u16; N]; + for _ in 0..64 { + Rpo256::apply_permutation(&mut state); + for a in &state[Rpo256::RATE_RANGE] { + res[i] = (a.as_int() % MODULUS as u64) as u16; + i += 1; + } + } + + // using the raw constructor is OK here because we reduce all coefficients by the modulus above + unsafe { Polynomial::new(res) } +} + +/// Converts byte representation of the nonce into field element representation. +fn decode_nonce(nonce: &NonceBytes) -> NonceElements { + let mut buffer = [0_u8; 8]; + let mut result = [ZERO; 8]; + for (i, bytes) in nonce.chunks(5).enumerate() { + buffer[..5].copy_from_slice(bytes); + result[i] = u64::from_le_bytes(buffer).into(); + } + + result +} + +// TESTS +// ================================================================================================ + +#[cfg(all(test, feature = "std"))] +mod tests { + use super::{ + super::{ffi::*, Felt}, + *, + }; + use libc::c_void; + use rand_utils::rand_vector; + + // Wrappers for unsafe functions + impl Rpo128Context { + /// Initializes the RPO state. + pub fn init() -> Self { + let mut ctx = Rpo128Context { content: [0u64; 13] }; + unsafe { + rpo128_init(&mut ctx as *mut Rpo128Context); + } + ctx + } + + /// Absorbs data into the RPO state. + pub fn absorb(&mut self, data: &[u8]) { + unsafe { + rpo128_absorb( + self as *mut Rpo128Context, + data.as_ptr() as *const c_void, + data.len(), + ) + } + } + + /// Finalizes the RPO state to prepare for squeezing. 
+ pub fn finalize(&mut self) { + unsafe { rpo128_finalize(self as *mut Rpo128Context) } + } + } + + #[test] + fn test_hash_to_point() { + // Create a random message and transform it into a u8 vector + let msg_felts: Word = rand_vector::(4).try_into().unwrap(); + let msg_bytes = msg_felts.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::>(); + + // Create a nonce i.e. a [u8; 40] array and pack into a [Felt; 8] array. + let nonce: [u8; 40] = rand_vector::(40).try_into().unwrap(); + + let mut buffer = [0_u8; 64]; + for i in 0..8 { + buffer[8 * i] = nonce[5 * i]; + buffer[8 * i + 1] = nonce[5 * i + 1]; + buffer[8 * i + 2] = nonce[5 * i + 2]; + buffer[8 * i + 3] = nonce[5 * i + 3]; + buffer[8 * i + 4] = nonce[5 * i + 4]; + } + + // Initialize the RPO state + let mut rng = Rpo128Context::init(); + + // Absorb the nonce and message into the RPO state + rng.absorb(&buffer); + rng.absorb(&msg_bytes); + rng.finalize(); + + // Generate the coefficients of the hash-to-point polynomial. + let mut res: [u16; N] = [0; N]; + + unsafe { + PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo( + &mut rng as *mut Rpo128Context, + res.as_mut_ptr(), + 9, + ); + } + + // Check that the coefficients are correct + let nonce = decode_nonce(&nonce); + assert_eq!(res, hash_to_point(msg_felts, &nonce).inner()); + } +} diff --git a/src/lib.rs b/src/lib.rs index cb0e11e..ddf146c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ #[cfg_attr(test, macro_use)] extern crate alloc; +pub mod dsa; pub mod hash; pub mod merkle; pub mod utils; diff --git a/src/merkle/merkle_tree.rs b/src/merkle/merkle_tree.rs index 206543a..769cff1 100644 --- a/src/merkle/merkle_tree.rs +++ b/src/merkle/merkle_tree.rs @@ -371,21 +371,9 @@ mod tests { let nodes: Vec = tree.inner_nodes().collect(); let expected = vec![ - InnerNodeInfo { - value: root, - left: l1n0, - right: l1n1, - }, - InnerNodeInfo { - value: l1n0, - left: l2n0, - right: l2n1, - }, - InnerNodeInfo { - value: l1n1, - left: l2n2, - right: l2n3, - }, + InnerNodeInfo { value: root, left: l1n0, right: l1n1 }, + InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 }, + InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 }, ]; assert_eq!(nodes, expected); diff --git a/src/merkle/mmr/full.rs b/src/merkle/mmr/full.rs index c3dd3ac..af7c212 100644 --- a/src/merkle/mmr/full.rs +++ b/src/merkle/mmr/full.rs @@ -71,10 +71,7 @@ impl Mmr { /// Constructor for an empty `Mmr`. pub fn new() -> Mmr { - Mmr { - forest: 0, - nodes: Vec::new(), - } + Mmr { forest: 0, nodes: Vec::new() } } // ACCESSORS @@ -188,10 +185,7 @@ impl Mmr { .map(|offset| self.nodes[offset - 1]) .collect(); - MmrPeaks { - num_leaves: self.forest, - peaks, - } + MmrPeaks { num_leaves: self.forest, peaks } } /// An iterator over inner nodes in the MMR. The order of iteration is unspecified. 
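The RPO Falcon512 code introduced above (`src/dsa/rpo_falcon512/polynomial.rs` and `signature.rs`) reduces verification to a small amount of modular polynomial arithmetic. For reference, here is a minimal standalone sketch of that acceptance condition. It is an illustration only, not part of the patch: it mirrors `mul_modulo_p`, `reduce_negacyclic`, and `sq_norm` over plain `u16` arrays rather than using the crate's `Polynomial` type, and the helper names are invented for this example.

```rust
// Standalone sketch of the acceptance condition checked by `Signature::verify`:
// compute s1 = c - s2 * h over Z_p[x]/(x^512 + 1) with p = 12289 and accept iff
// ||s1||^2 + ||s2||^2 <= 34034726.

const N: usize = 512;
const P: u64 = 12289;
const SIG_L2_BOUND: u64 = 34034726;

/// Multiplies two polynomials and reduces the product modulo x^N + 1 and modulo p.
fn mul_negacyclic(a: &[u16; N], b: &[u16; N]) -> [u16; N] {
    // schoolbook product; coefficients stay below 512 * 12288^2, well within u64
    let mut prod = [0u64; 2 * N];
    for i in 0..N {
        for j in 0..N {
            prod[i + j] += a[i] as u64 * b[j] as u64;
        }
    }
    // x^(N + i) = -x^i in Z_p[x]/(x^N + 1)
    let mut out = [0u16; N];
    for i in 0..N {
        let lo = prod[i] % P;
        let hi = prod[N + i] % P;
        out[i] = ((lo + P - hi) % P) as u16;
    }
    out
}

/// Squared norm after normalizing coefficients into the interval (-p/2, p/2].
fn sq_norm(a: &[u16; N]) -> u64 {
    a.iter()
        .map(|&e| {
            let e = e as u64;
            let e = if e > (P - 1) / 2 { P - e } else { e };
            e * e
        })
        .sum()
}

/// Returns true if `s2` is a valid signature given the hash-to-point polynomial `c`
/// and the public key polynomial `h`.
fn verify(c: &[u16; N], s2: &[u16; N], h: &[u16; N]) -> bool {
    let s2h = mul_negacyclic(s2, h);
    let mut s1 = [0u16; N];
    for i in 0..N {
        s1[i] = ((c[i] as u64 + P - s2h[i] as u64) % P) as u16;
    }
    sq_norm(&s1) + sq_norm(s2) <= SIG_L2_BOUND
}
```

With `c` produced by the hash-to-point routine and `h` decoded from the 897-byte public key, this is the condition documented on `Signature`; the actual `Signature::verify` additionally checks that the RPO hash of the expanded public key matches the expected public key commitment before evaluating it.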
diff --git a/src/merkle/mmr/tests.rs b/src/merkle/mmr/tests.rs index 13722e2..3cecaa4 100644 --- a/src/merkle/mmr/tests.rs +++ b/src/merkle/mmr/tests.rs @@ -380,11 +380,7 @@ fn test_mmr_inner_nodes() { left: LEAVES[2], right: LEAVES[3], }, - InnerNodeInfo { - value: h0123, - left: h01, - right: h23, - }, + InnerNodeInfo { value: h0123, left: h01, right: h23 }, InnerNodeInfo { value: h45, left: LEAVES[4], diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index 2c95972..c29b6e8 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -158,11 +158,7 @@ impl PartialMerkleTree { } } - Ok(PartialMerkleTree { - max_depth, - nodes, - leaves, - }) + Ok(PartialMerkleTree { max_depth, nodes, leaves }) } // PUBLIC ACCESSORS diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 2d11bb3..4535ba9 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -137,11 +137,7 @@ impl<'a> Iterator for InnerNodeIterator<'a> { self.value = Rpo256::merge(&[left, right]); self.index.move_up(); - Some(InnerNodeInfo { - value: self.value, - left, - right, - }) + Some(InnerNodeInfo { value: self.value, left, right }) } else { None } @@ -163,10 +159,7 @@ pub struct ValuePath { impl ValuePath { /// Returns a new [ValuePath] instantiated from the specified value and path. pub fn new(value: RpoDigest, path: Vec) -> Self { - Self { - value, - path: MerklePath::new(path), - } + Self { value, path: MerklePath::new(path) } } } diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index 7c40d90..a951122 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -249,10 +249,7 @@ impl SimpleSmt { fn get_branch_node(&self, index: &NodeIndex) -> BranchNode { self.branches.get(index).cloned().unwrap_or_else(|| { let node = self.empty_hashes[index.depth() as usize + 1]; - BranchNode { - left: node, - right: node, - } + BranchNode { left: node, right: node } }) } diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 26aaab2..f2c55d1 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -123,21 +123,9 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> { let nodes: Vec = tree.inner_nodes().collect(); let expected = vec![ - InnerNodeInfo { - value: root, - left: l1n0, - right: l1n1, - }, - InnerNodeInfo { - value: l1n0, - left: l2n0, - right: l2n1, - }, - InnerNodeInfo { - value: l1n1, - left: l2n2, - right: l2n3, - }, + InnerNodeInfo { value: root, left: l1n0, right: l1n1 }, + InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 }, + InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 }, ]; assert_eq!(nodes, expected); diff --git a/src/merkle/store/mod.rs b/src/merkle/store/mod.rs index a8dd1fb..caba99e 100644 --- a/src/merkle/store/mod.rs +++ b/src/merkle/store/mod.rs @@ -326,11 +326,9 @@ impl> MerkleStore { /// Iterator over the inner nodes of the [MerkleStore]. 
pub fn inner_nodes(&self) -> impl Iterator + '_ { - self.nodes.iter().map(|(r, n)| InnerNodeInfo { - value: *r, - left: n.left, - right: n.right, - }) + self.nodes + .iter() + .map(|(r, n)| InnerNodeInfo { value: *r, left: n.left, right: n.right }) } /// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root` @@ -450,13 +448,7 @@ impl> MerkleStore { right_root: RpoDigest, ) -> Result { let parent = Rpo256::merge(&[left_root, right_root]); - self.nodes.insert( - parent, - StoreNode { - left: left_root, - right: right_root, - }, - ); + self.nodes.insert(parent, StoreNode { left: left_root, right: right_root }); Ok(parent) } @@ -551,15 +543,10 @@ impl> FromIterator<(RpoDigest, StoreNode)> for Me // ================================================================================================ impl> Extend for MerkleStore { fn extend>(&mut self, iter: I) { - self.nodes.extend(iter.into_iter().map(|info| { - ( - info.value, - StoreNode { - left: info.left, - right: info.right, - }, - ) - })); + self.nodes.extend( + iter.into_iter() + .map(|info| (info.value, StoreNode { left: info.left, right: info.right })), + ); } } @@ -646,17 +633,12 @@ impl> Deserializable for MerkleStore { /// Creates empty hashes for all the subtrees of a tree with a max depth of 255. fn empty_hashes() -> impl IntoIterator { let subtrees = EmptySubtreeRoots::empty_hashes(255); - subtrees.iter().rev().copied().zip(subtrees.iter().rev().skip(1).copied()).map( - |(child, parent)| { - ( - parent, - StoreNode { - left: child, - right: child, - }, - ) - }, - ) + subtrees + .iter() + .rev() + .copied() + .zip(subtrees.iter().rev().skip(1).copied()) + .map(|(child, parent)| (parent, StoreNode { left: child, right: child })) } /// Consumes an iterator of [InnerNodeInfo] and returns an iterator of `(value, node)` tuples @@ -666,14 +648,6 @@ fn combine_nodes_with_empty_hashes( ) -> impl Iterator { nodes .into_iter() - .map(|info| { - ( - info.value, - StoreNode { - left: info.left, - right: info.right, - }, - ) - }) + .map(|info| (info.value, StoreNode { left: info.left, right: info.right })) .chain(empty_hashes()) } diff --git a/src/merkle/tiered_smt/values.rs b/src/merkle/tiered_smt/values.rs index d41ee6b..4fac76a 100644 --- a/src/merkle/tiered_smt/values.rs +++ b/src/merkle/tiered_smt/values.rs @@ -213,10 +213,7 @@ impl StoreEntry { /// Returns an iterator over all key-value pairs in this entry. 
pub fn iter(&self) -> impl Iterator { - EntryIterator { - entry: self, - pos: 0, - } + EntryIterator { entry: self, pos: 0 } } // STATE MUTATORS diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8aadabe..7449c89 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -83,10 +83,7 @@ impl std::error::Error for HexParseError {} pub fn hex_to_bytes(value: &str) -> Result<[u8; N], HexParseError> { let expected: usize = (N * 2) + 2; if value.len() != expected { - return Err(HexParseError::InvalidLength { - expected, - got: value.len(), - }); + return Err(HexParseError::InvalidLength { expected, got: value.len() }); } if !value.starts_with("0x") { From cf91c898459d9e45cbecb41e4c047d6202d4705f Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Tue, 3 Oct 2023 18:08:36 -0700 Subject: [PATCH 30/32] refactor: clean up features --- CHANGELOG.md | 2 + Cargo.toml | 17 +++-- README.md | 6 ++ benches/README.md | 7 +- build.rs | 29 ++++++--- .../falcon_c/{falcon_rpo.c => falcon.c} | 65 ++++++++++++------- .../falcon_c/{api_rpo.h => falcon.h} | 0 src/dsa/rpo_falcon512/falcon_c/rpo.c | 26 ++++---- src/dsa/rpo_falcon512/ffi.rs | 12 ++-- src/hash/rpo/mod.rs | 10 +-- src/main.rs | 10 +-- src/merkle/tiered_smt/mod.rs | 2 +- 12 files changed, 110 insertions(+), 76 deletions(-) rename src/dsa/rpo_falcon512/falcon_c/{falcon_rpo.c => falcon.c} (92%) rename src/dsa/rpo_falcon512/falcon_c/{api_rpo.h => falcon.h} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ec4658..d5ca98e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,9 +3,11 @@ * Replaced `MerklePathSet` with `PartialMerkleTree` (#165). * Implemented clearing of nodes in `TieredSmt` (#173). * Added ability to generate inclusion proofs for `TieredSmt` (#174). +* Implemented Falcon DSA (#179). * Added conditional `serde`` support for various structs (#180). * Implemented benchmarking for `TieredSmt` (#182). * Added more leaf traversal methods for `MerkleStore` (#185). +* Added SVE acceleration for RPO hash function (#189). 
## 0.6.0 (2023-06-25) diff --git a/Cargo.toml b/Cargo.toml index c77a1d0..49ddaa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,22 +32,21 @@ name = "store" harness = false [features] -arch-arm64-sve = ["dep:cc"] -default = ["blake3/default", "std", "winter_crypto/default", "winter_math/default", "winter_utils/default"] +default = ["std"] executable = ["dep:clap", "dep:rand_utils", "std"] -std = ["blake3/std", "dep:cc", "dep:libc", "dep:rand", "winter_crypto/std", "winter_math/std", "winter_utils/std"] -serde = ["winter_math/serde", "dep:serde", "serde/alloc"] +serde = ["dep:serde", "serde?/alloc", "winter_math/serde"] +std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std", "winter_utils/std"] +sve = ["std"] [dependencies] blake3 = { version = "1.4", default-features = false } -clap = { version = "4.3", features = ["derive"], optional = true} -libc = { version = "0.2", optional = true, default-features = false } -rand = { version = "0.8", optional = true, default-features = false } +clap = { version = "4.3", features = ["derive"], optional = true } +libc = { version = "0.2", default-features = false, optional = true } +rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true } +serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true } winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } winter_math = { version = "0.6", package = "winter-math", default-features = false } winter_utils = { version = "0.6", package = "winter-utils", default-features = false } -serde = { version = "1.0", features = [ "derive" ], optional = true, default-features = false } -rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/README.md b/README.md index 6274cea..1f9d27e 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,12 @@ Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/ To compile with `no_std`, disable default features via `--no-default-features` flag. +### SVE acceleration +On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` feature enabled. This feature has an effect only if the platform exposes `target-feature=sve` flag. On some platforms (e.g., Graviton 3), for this flag to be set, the compilation must be done in "native" mode. 
For example, to enable SVE acceleration on Graviton 3, we can execute the following: +```shell +RUSTFLAGS="-C target-cpu=native" cargo build --release --features sve +``` + ## Testing You can use cargo defaults to test the library: diff --git a/benches/README.md b/benches/README.md index 385e01e..a1dddd0 100644 --- a/benches/README.md +++ b/benches/README.md @@ -19,7 +19,7 @@ The second scenario is that of sequential hashing where we take a sequence of le | ------------------- | ------ | --------| --------- | --------- | ------- | | Apple M1 Pro | 80 ns | 245 ns | 1.5 us | 9.1 us | 5.4 us | | Apple M2 | 76 ns | 233 ns | 1.3 us | 7.9 us | 5.0 us | -| Amazon Graviton 3 | 116 ns | | | | 8.8 us | +| Amazon Graviton 3 | 108 ns | | | | 5.3 us | | AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 us | 9.1 us | 5.5 us | | Intel Core i5-8279U | 80 ns | | | | 8.7 us | | Intel Xeon 8375C | 67 ns | | | | 8.2 us | @@ -30,11 +30,14 @@ The second scenario is that of sequential hashing where we take a sequence of le | ------------------- | -------| ------- | --------- | --------- | ------- | | Apple M1 Pro | 1.0 us | 1.5 us | 19.4 us | 118 us | 70 us | | Apple M2 | 1.0 us | 1.5 us | 17.4 us | 103 us | 65 us | -| Amazon Graviton 3 | 1.4 us | | | | 114 us | +| Amazon Graviton 3 | 1.4 us | | | | 69 us | | AMD Ryzen 9 5950X | 0.8 us | 1.7 us | 15.7 us | 120 us | 72 us | | Intel Core i5-8279U | 1.0 us | | | | 116 us | | Intel Xeon 8375C | 0.8 ns | | | | 110 us | +Notes: +- On Graviton 3, RPO256 is run with SVE acceleration enabled. + ### Instructions Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following: diff --git a/build.rs b/build.rs index e27f9df..8fbd225 100644 --- a/build.rs +++ b/build.rs @@ -2,7 +2,7 @@ fn main() { #[cfg(feature = "std")] compile_rpo_falcon(); - #[cfg(feature = "arch-arm64-sve")] + #[cfg(all(target_feature = "sve", feature = "sve"))] compile_arch_arm64_sve(); } @@ -10,31 +10,40 @@ fn main() { fn compile_rpo_falcon() { use std::path::PathBuf; + const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c"; + + println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h"); + println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c"); + println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h"); + println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c"); + let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect(); let common_dir: PathBuf = ["PQClean", "common"].iter().collect(); - let rpo_dir: PathBuf = ["src", "dsa", "rpo_falcon512", "falcon_c"].iter().collect(); let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap(); let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap(); - let rpo_files = glob::glob(rpo_dir.join("*.c").to_str().unwrap()).unwrap(); cc::Build::new() .include(&common_dir) .include(target_dir) .files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) .files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) - .files(rpo_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned())) - .compile("falcon-512_clean"); + .file(format!("{RPO_FALCON_PATH}/falcon.c")) + .file(format!("{RPO_FALCON_PATH}/rpo.c")) + .flag("-O3") + .compile("rpo_falcon512"); } -#[cfg(feature = "arch-arm64-sve")] +#[cfg(all(target_feature = "sve", 
feature = "sve"))] fn compile_arch_arm64_sve() { - println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/library.c"); - println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/library.h"); - println!("cargo:rerun-if-changed=arch/arm64-sve/rpo/rpo_hash.h"); + const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo"; + + println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.c"); + println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.h"); + println!("cargo:rerun-if-changed={RPO_SVE_PATH}/rpo_hash.h"); cc::Build::new() - .file("arch/arm64-sve/rpo/library.c") + .file(format!("{RPO_SVE_PATH}/library.c")) .flag("-march=armv8-a+sve") .flag("-O3") .compile("rpo_sve"); diff --git a/src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c b/src/dsa/rpo_falcon512/falcon_c/falcon.c similarity index 92% rename from src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c rename to src/dsa/rpo_falcon512/falcon_c/falcon.c index a1294d2..cd7bed5 100644 --- a/src/dsa/rpo_falcon512/falcon_c/falcon_rpo.c +++ b/src/dsa/rpo_falcon512/falcon_c/falcon.c @@ -4,7 +4,7 @@ #include #include "randombytes.h" -#include "api_rpo.h" +#include "falcon.h" #include "inner.h" #include "rpo.h" @@ -37,10 +37,12 @@ * (signature length is 1+len(value), not counting the nonce) */ -/* see api_rpo.h */ +/* see falcon.h */ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( - uint8_t *pk, uint8_t *sk, unsigned char *seed) -{ + uint8_t *pk, + uint8_t *sk, + unsigned char *seed +) { union { uint8_t b[FALCON_KEYGEN_TEMP_9]; @@ -111,8 +113,9 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( } int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( - uint8_t *pk, uint8_t *sk) -{ + uint8_t *pk, + uint8_t *sk +) { unsigned char seed[48]; /* @@ -137,10 +140,14 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( * * Return value: 0 on success, -1 on error. */ -static int -do_sign(uint8_t *nonce, uint8_t *sigbuf, size_t *sigbuflen, - const uint8_t *m, size_t mlen, const uint8_t *sk) -{ +static int do_sign( + uint8_t *nonce, + uint8_t *sigbuf, + size_t *sigbuflen, + const uint8_t *m, + size_t mlen, + const uint8_t *sk +) { union { uint8_t b[72 * 512]; @@ -261,11 +268,14 @@ do_sign(uint8_t *nonce, uint8_t *sigbuf, size_t *sigbuflen, * (of size sigbuflen) contains the signature value, not including the * header byte or nonce. Return value is 0 on success, -1 on error. 
*/ -static int -do_verify( - const uint8_t *nonce, const uint8_t *sigbuf, size_t sigbuflen, - const uint8_t *m, size_t mlen, const uint8_t *pk) -{ +static int do_verify( + const uint8_t *nonce, + const uint8_t *sigbuf, + size_t sigbuflen, + const uint8_t *m, + size_t mlen, + const uint8_t *pk +) { union { uint8_t b[2 * 512]; @@ -341,11 +351,14 @@ do_verify( return 0; } -/* see api_rpo.h */ +/* see falcon.h */ int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( - uint8_t *sig, size_t *siglen, - const uint8_t *m, size_t mlen, const uint8_t *sk) -{ + uint8_t *sig, + size_t *siglen, + const uint8_t *m, + size_t mlen, + const uint8_t *sk +) { /* * The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for * the signed message object (as produced by crypto_sign()) @@ -369,11 +382,14 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo( return 0; } -/* see api_rpo.h */ +/* see falcon.h */ int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( - const uint8_t *sig, size_t siglen, - const uint8_t *m, size_t mlen, const uint8_t *pk) -{ + const uint8_t *sig, + size_t siglen, + const uint8_t *m, + size_t mlen, + const uint8_t *pk +) { if (siglen < 1 + NONCELEN) { return -1; @@ -382,6 +398,5 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo( { return -1; } - return do_verify(sig + 1, - sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk); + return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk); } diff --git a/src/dsa/rpo_falcon512/falcon_c/api_rpo.h b/src/dsa/rpo_falcon512/falcon_c/falcon.h similarity index 100% rename from src/dsa/rpo_falcon512/falcon_c/api_rpo.h rename to src/dsa/rpo_falcon512/falcon_c/falcon.h diff --git a/src/dsa/rpo_falcon512/falcon_c/rpo.c b/src/dsa/rpo_falcon512/falcon_c/rpo.c index 824f8a2..a33e81d 100644 --- a/src/dsa/rpo_falcon512/falcon_c/rpo.c +++ b/src/dsa/rpo_falcon512/falcon_c/rpo.c @@ -14,7 +14,7 @@ #define M 12289 // From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go -uint64_t add_mod_p(uint64_t a, uint64_t b) +static uint64_t add_mod_p(uint64_t a, uint64_t b) { a = P - a; uint64_t res = b - a; @@ -23,7 +23,7 @@ uint64_t add_mod_p(uint64_t a, uint64_t b) return res; } -uint64_t sub_mod_p(uint64_t a, uint64_t b) +static uint64_t sub_mod_p(uint64_t a, uint64_t b) { uint64_t r = a - b; if (a < b) @@ -31,7 +31,7 @@ uint64_t sub_mod_p(uint64_t a, uint64_t b) return r; } -uint64_t reduce_mod_p(uint64_t b, uint64_t a) +static uint64_t reduce_mod_p(uint64_t b, uint64_t a) { uint32_t d = b >> 32, c = b; @@ -43,7 +43,7 @@ uint64_t reduce_mod_p(uint64_t b, uint64_t a) return a; } -uint64_t mult_mod_p(uint64_t x, uint64_t y) +static uint64_t mult_mod_p(uint64_t x, uint64_t y) { uint32_t a = x, b = x >> 32, @@ -85,7 +85,7 @@ static const uint64_t NUM_ROUNDS = 7; /* * MDS matrix */ -const uint64_t MDS[12][12] = { +static const uint64_t MDS[12][12] = { { 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 }, { 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 }, { 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 }, @@ -103,7 +103,7 @@ const uint64_t MDS[12][12] = { /* * Round constants. 
*/ -const uint64_t ARK1[7][12] = { +static const uint64_t ARK1[7][12] = { { 5789762306288267392ULL, 6522564764413701783ULL, @@ -304,7 +304,7 @@ const uint64_t ARK2[7][12] = { }, }; -void apply_sbox(uint64_t *const state) +static void apply_sbox(uint64_t *const state) { for (uint64_t i = 0; i < STATE_WIDTH; i++) { @@ -315,7 +315,7 @@ void apply_sbox(uint64_t *const state) } } -void apply_mds(uint64_t *state) +static void apply_mds(uint64_t *state) { uint64_t res[STATE_WIDTH]; for (uint64_t i = 0; i < STATE_WIDTH; i++) @@ -336,7 +336,7 @@ void apply_mds(uint64_t *state) } } -void apply_constants(uint64_t *const state, const uint64_t *ark) +static void apply_constants(uint64_t *const state, const uint64_t *ark) { for (uint64_t i = 0; i < STATE_WIDTH; i++) { @@ -344,7 +344,7 @@ void apply_constants(uint64_t *const state, const uint64_t *ark) } } -void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res) +static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res) { for (uint64_t i = 0; i < m; i++) { @@ -367,7 +367,7 @@ void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint6 } } -void apply_inv_sbox(uint64_t *const state) +static void apply_inv_sbox(uint64_t *const state) { uint64_t t1[STATE_WIDTH]; for (uint64_t i = 0; i < STATE_WIDTH; i++) @@ -435,7 +435,7 @@ void apply_inv_sbox(uint64_t *const state) } } -void apply_round(uint64_t *const state, const uint64_t round) +static void apply_round(uint64_t *const state, const uint64_t round) { apply_mds(state); apply_constants(state, ARK1[round]); @@ -579,4 +579,4 @@ void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, *x++ = (uint16_t)w; n--; } -} \ No newline at end of file +} diff --git a/src/dsa/rpo_falcon512/ffi.rs b/src/dsa/rpo_falcon512/ffi.rs index f21b3ed..4508ff2 100644 --- a/src/dsa/rpo_falcon512/ffi.rs +++ b/src/dsa/rpo_falcon512/ffi.rs @@ -3,6 +3,7 @@ use libc::c_int; // C IMPLEMENTATION INTERFACE // ================================================================================================ +#[link(name = "rpo_falcon512", kind = "static")] extern "C" { /// Generate a new key pair. Public key goes into pk[], private key in sk[]. 
/// Key sizes are exact (in bytes): @@ -97,19 +98,18 @@ pub struct Rpo128Context { mod tests { use super::*; use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN}; - use rand::Rng; + use rand_utils::{rand_array, rand_value, rand_vector}; #[test] fn falcon_ffi() { unsafe { - let mut rng = rand::thread_rng(); + //let mut rng = rand::thread_rng(); // --- generate a key pair from a seed ---------------------------- let mut pk = [0u8; PK_LEN]; let mut sk = [0u8; SK_LEN]; - let seed: [u8; NONCE_LEN] = - (0..NONCE_LEN).map(|_| rng.gen()).collect::>().try_into().unwrap(); + let seed: [u8; NONCE_LEN] = rand_array(); assert_eq!( 0, @@ -122,8 +122,8 @@ mod tests { // --- sign a message and make sure it verifies ------------------- - let mlen: usize = rng.gen::() as usize; - let msg: Vec = (0..mlen).map(|_| rng.gen()).collect(); + let mlen: usize = rand_value::() as usize; + let msg: Vec = rand_vector(mlen); let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN]; let mut siglen = 0; diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs index dc7df3f..fafce89 100644 --- a/src/hash/rpo/mod.rs +++ b/src/hash/rpo/mod.rs @@ -10,7 +10,7 @@ use mds_freq::mds_multiply_freq; #[cfg(test)] mod tests; -#[cfg(feature = "arch-arm64-sve")] +#[cfg(all(target_feature = "sve", feature = "sve"))] #[link(name = "rpo_sve", kind = "static")] extern "C" { fn add_constants_and_apply_sbox( @@ -375,7 +375,7 @@ impl Rpo256 { // -------------------------------------------------------------------------------------------- #[inline(always)] - #[cfg(feature = "arch-arm64-sve")] + #[cfg(all(target_feature = "sve", feature = "sve"))] fn optimized_add_constants_and_apply_sbox( state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH], @@ -386,7 +386,7 @@ impl Rpo256 { } #[inline(always)] - #[cfg(not(feature = "arch-arm64-sve"))] + #[cfg(not(all(target_feature = "sve", feature = "sve")))] fn optimized_add_constants_and_apply_sbox( _state: &mut [Felt; STATE_WIDTH], _ark: &[Felt; STATE_WIDTH], @@ -395,7 +395,7 @@ impl Rpo256 { } #[inline(always)] - #[cfg(feature = "arch-arm64-sve")] + #[cfg(all(target_feature = "sve", feature = "sve"))] fn optimized_add_constants_and_apply_inv_sbox( state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH], @@ -409,7 +409,7 @@ impl Rpo256 { } #[inline(always)] - #[cfg(not(feature = "arch-arm64-sve"))] + #[cfg(not(all(target_feature = "sve", feature = "sve")))] fn optimized_add_constants_and_apply_inv_sbox( _state: &mut [Felt; STATE_WIDTH], _ark: &[Felt; STATE_WIDTH], diff --git a/src/main.rs b/src/main.rs index 800306c..e9f8299 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,23 +31,23 @@ pub fn benchmark_tsmt() { let tree_size = args.size; // prepare the `leaves` vector for tree creation - let mut leaves = Vec::new(); + let mut entries = Vec::new(); for i in 0..tree_size { let key = rand_value::(); let value = [ONE, ONE, ONE, Felt::new(i)]; - leaves.push((key, value)); + entries.push((key, value)); } - let mut tree = construction(leaves, tree_size).unwrap(); + let mut tree = construction(entries, tree_size).unwrap(); insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); } /// Runs the construction benchmark for the Tiered SMT, returning the constructed tree. 
-pub fn construction(leaves: Vec<(RpoDigest, Word)>, size: u64) -> Result {
+pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result {
     println!("Running a construction benchmark:");
     let now = Instant::now();
-    let tree = TieredSmt::with_leaves(leaves)?;
+    let tree = TieredSmt::with_entries(entries)?;
     let elapsed = now.elapsed();
     println!(
         "Constructed a TSMT with {} key-value pairs in {:.3} seconds",
diff --git a/src/merkle/tiered_smt/mod.rs b/src/merkle/tiered_smt/mod.rs
index 7fbb2ca..63e4aa8 100644
--- a/src/merkle/tiered_smt/mod.rs
+++ b/src/merkle/tiered_smt/mod.rs
@@ -73,7 +73,7 @@ impl TieredSmt {
     ///
     /// # Errors
     /// Returns an error if the provided entries contain multiple values for the same key.
-    pub fn with_leaves(entries: R) -> Result
+    pub fn with_entries(entries: R) -> Result
     where
         R: IntoIterator,
         I: Iterator + ExactSizeIterator,

From aeadc96b0521f3a347b28191fba7ed2481c12bd7 Mon Sep 17 00:00:00 2001
From: Bobbin Threadbare
Date: Fri, 6 Oct 2023 00:11:32 -0700
Subject: [PATCH 31/32] docs: add signature section to main readme

---
 Cargo.toml | 8 ++++----
 README.md  | 7 +++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 49ddaa8..8e4b8db 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,7 +10,7 @@ documentation = "https://docs.rs/miden-crypto/0.7.0"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
-rust-version = "1.67"
+rust-version = "1.73"
 
 [[bin]]
 name = "miden-crypto"
@@ -39,8 +39,8 @@ std = ["blake3/std", "dep:cc", "dep:libc", "winter_crypto/std", "winter_math/std
 sve = ["std"]
 
 [dependencies]
-blake3 = { version = "1.4", default-features = false }
-clap = { version = "4.3", features = ["derive"], optional = true }
+blake3 = { version = "1.5", default-features = false }
+clap = { version = "4.4", features = ["derive"], optional = true }
 libc = { version = "0.2", default-features = false, optional = true }
 rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true }
 serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true }
@@ -50,7 +50,7 @@ winter_utils = { version = "0.6", package = "winter-utils", default-features = f
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
-proptest = "1.1.0"
+proptest = "1.3"
 rand_utils = { version = "0.6", package = "winter-rand-utils" }
 
 [build-dependencies]
diff --git a/README.md b/README.md
index 1f9d27e..156a1b0 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,13 @@ For performance benchmarks of these hash functions and their comparison to other
 
 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
 
+## Signatures
+[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in Miden VM. Currently, these schemes are:
+
+* `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using the SHAKE256 hash function in the *hash-to-point* algorithm, we use RPO256. This makes the signature more efficient to verify in Miden VM.
+
+For the above signatures, key generation and signing are available only in the `std` context (see [crate features](#crate-features) below), while signature verification is available in `no_std` context as well.
+ ## Crate features This crate can be compiled with the following features: From 9235a78afd743bf863757820984351a4e86c01ea Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare Date: Fri, 6 Oct 2023 17:06:06 -0700 Subject: [PATCH 32/32] chore: add date for v0.7 release --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5ca98e..3372864 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## 0.7.0 (TBD) +## 0.7.0 (2023-10-05) * Replaced `MerklePathSet` with `PartialMerkleTree` (#165). * Implemented clearing of nodes in `TieredSmt` (#173).
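To complement the README section added in [PATCH 31/32], here is a minimal usage sketch of the API introduced by this patch series. It assumes the signature module is exposed as `miden_crypto::dsa::rpo_falcon512` (lib.rs adds `pub mod dsa;`) and that `Felt` and `Word` are re-exported from the crate root, as they are used throughout these diffs; error handling is reduced to `expect` for brevity.

```rust
use miden_crypto::{dsa::rpo_falcon512::KeyPair, Felt, Word};

fn main() {
    // Key generation and signing call into the C code compiled by build.rs and are
    // therefore gated on the `std` feature; verification also works under `no_std`.
    let keys = KeyPair::new().expect("key generation failed");

    // A message is a single `Word`, i.e., four field elements.
    let message: Word = [Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)];

    // Sign the message and verify the signature against the public key commitment.
    let signature = keys.sign(message).expect("signing failed");
    assert!(keys.public_key().verify(message, &signature));
}
```

The same flow appears in the `test_falcon_verification` test above; `KeyPair::from_seed` can be used instead of `KeyPair::new` when deterministic key generation from a 40-byte seed is needed.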