Compare commits

...

13 commits

Author SHA1 Message Date
9fe2673a33 fixup doc comment 2025-02-28 15:38:42 +01:00
fa0943fc48 WIP: smt: add SparseMerklePath 2025-02-28 15:38:32 +01:00
bf595cc5c0 init 2025-02-28 15:36:56 +01:00
901816f139 WIP: smt: factor out MerklePath logic 2025-02-27 17:30:16 +01:00
609898efeb WIP: smt: add SparseMerklePath 2025-02-27 17:30:16 +01:00
1e87cd60ff
docs: add SMT benchmarks (#384) 2025-02-25 13:33:19 -08:00
b97243c582
fix: dead_code warning on pairs_to_leaf when not(feature = "concurrent") (#380)
This also moves `pairs_to_leaf()` out of the `SparseMerkleTree` trait,
also removing it from `SimpleSmt`, as `pairs_to_leaf()` is only ever
used in concurrent code for `Smt`.

This fixes a warning with `--no-default-features`.
2025-02-24 10:26:56 -08:00
Philipp Gackstatter
d0e9ead6f4
feat: filter empty values in Smt::with_entries (#383) 2025-02-18 02:18:47 -08:00
Bobbin Threadbare
2ba30bf3bf
fix: error in Cargo.lock 2025-02-18 01:19:15 -08:00
Bobbin Threadbare
0514a8316a
Merge branch 'main' into next 2025-02-18 01:14:07 -08:00
Bobbin Threadbare
8ce7b68d68
chore: increment crate version to v0.13.3 and update changelog 2025-02-18 01:09:11 -08:00
Philipp Gackstatter
535637d7fb
fix: panic in PartialMmr::untrack (#382) 2025-02-18 01:04:21 -08:00
Philipp Gackstatter
ed14eaa90c
fix: PartialSmt stale proofs not resulting in error (#381) 2025-02-17 15:40:56 -08:00
101 changed files with 1 additions and 27088 deletions

View file

@ -1,3 +0,0 @@
[profile.default]
failure-output = "immediate-final"
fail-fast = false

View file

@ -1,20 +0,0 @@
# Documentation available at editorconfig.org
root=true
[*]
ident_style = space
ident_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.rs]
max_line_length = 100
[*.md]
trim_trailing_whitespace = false
[*.yml]
ident_size = 2

View file

@ -1,9 +0,0 @@
## Describe your changes
## Checklist before requesting a review
- Repo forked and branch created from `next` according to naming convention.
- Commit messages and codestyle follow [conventions](./CONTRIBUTING.md).
- Relevant issues are linked in the PR description.
- Tests added for new functionality.
- Documentation/comments updated according to changes.

View file

@ -1,25 +0,0 @@
# Runs build related jobs.
name: build
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
no-std:
name: Build for no-std
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
steps:
- uses: actions/checkout@main
- name: Build for no-std
run: |
rustup update --no-self-update ${{ matrix.toolchain }}
rustup target add wasm32-unknown-unknown
make build-no-std

View file

@ -1,23 +0,0 @@
# Runs changelog related jobs.
# CI job heavily inspired by: https://github.com/tarides/changelog-check-action
name: changelog
on:
pull_request:
types: [opened, reopened, synchronize, labeled, unlabeled]
jobs:
changelog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@main
with:
fetch-depth: 0
- name: Check for changes in changelog
env:
BASE_REF: ${{ github.event.pull_request.base.ref }}
NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }}
run: ./scripts/check-changelog.sh "${{ inputs.changelog }}"
shell: bash

View file

@ -1,53 +0,0 @@
# Runs linting related jobs.
name: lint
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
clippy:
name: clippy nightly on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Clippy
run: |
rustup update --no-self-update nightly
rustup +nightly component add clippy
make clippy
rustfmt:
name: rustfmt check nightly on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Rustfmt
run: |
rustup update --no-self-update nightly
rustup +nightly component add rustfmt
make format-check
doc:
name: doc stable on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Build docs
run: |
rustup update --no-self-update
make doc
version:
name: check rust version consistency
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
with:
profile: minimal
override: true
- name: check rust versions
run: ./scripts/check-rust-version.sh

View file

@ -1,28 +0,0 @@
# Runs test related jobs.
name: test
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
test:
name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
args: [default, smt-hashmaps, no-std]
timeout-minutes: 30
steps:
- uses: actions/checkout@main
- uses: taiki-e/install-action@nextest
- name: Perform tests
run: |
rustup update --no-self-update ${{matrix.toolchain}}
make test-${{matrix.args}}

12
.gitignore vendored
View file

@ -1,12 +0,0 @@
# Generated by Cargo
# will have compiled files and executables
/target/
# These are backup files generated by rustfmt
**/*.rs.bk
# Generated by cmake
cmake-build-*
# VS Code
.vscode/

0
.gitmodules vendored
View file

View file

@ -1,34 +0,0 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-json
- id: check-toml
- id: pretty-format-json
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: detect-private-key
- repo: local
hooks:
- id: lint
name: Make lint
stages: [commit]
language: rust
entry: make lint
- id: doc
name: Make doc
stages: [commit]
language: rust
entry: make doc
- id: check
name: Make check
stages: [commit]
language: rust
entry: make check

View file

@ -1,186 +0,0 @@
## 0.14.0 (TBD)
- [BREAKING] Increment minimum supported Rust version to 1.84.
- Removed duplicated check in RpoFalcon512 verification (#368).
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365).
- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365).
- [BREAKING] Updated Winterfell dependency to v0.12 (#374).
- Added debug-only duplicate column check in `build_subtree` (#378).
## 0.13.3 (2025-02-12)
- Implement `PartialSmt` (#372).
## 0.13.2 (2025-01-24)
- Made `InnerNode` and `NodeMutation` public. Implemented (de)serialization of `LeafIndex` (#367).
## 0.13.1 (2024-12-26)
- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
## 0.13.0 (2024-11-24)
- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363).
## 0.12.0 (2024-10-30)
- [BREAKING] Updated Winterfell dependency to v0.10 (#338).
- Added parallel implementation of `Smt::with_entries()` with significantly better performance when the `concurrent` feature is enabled (#341).
## 0.11.0 (2024-10-17)
- [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234).
- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234).
- Standardized CI and Makefile across Miden repos (#323).
- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
- Changed padding rule for RPO/RPX hash functions (#318).
- [BREAKING] Changed return value of the `Mmr::verify()` and `MerklePath::verify()` from `bool` to `Result<>` (#335).
- Added `is_empty()` functions to the `SimpleSmt` and `Smt` structures. Added `EMPTY_ROOT` constant to the `SparseMerkleTree` trait (#337).
## 0.10.3 (2024-09-25)
- Implement `get_size_hint` for `Smt` (#331).
## 0.10.2 (2024-09-25)
- Implement `get_size_hint` for `RpoDigest` and `RpxDigest` and expose constants for their serialized size (#330).
## 0.10.1 (2024-09-13)
- Added `Serializable` and `Deserializable` implementations for `PartialMmr` and `InOrderIndex` (#329).
## 0.10.0 (2024-08-06)
- Added more `RpoDigest` and `RpxDigest` conversions (#311).
- [BREAKING] Migrated to Winterfell v0.9 (#315).
- Fixed encoding of Falcon secret key (#319).
## 0.9.3 (2024-04-24)
- Added `RpxRandomCoin` struct (#307).
## 0.9.2 (2024-04-21)
- Implemented serialization for the `Smt` struct (#304).
- Fixed a bug in Falcon signature generation (#305).
## 0.9.1 (2024-04-02)
- Added `num_leaves()` method to `SimpleSmt` (#302).
## 0.9.0 (2024-03-24)
- [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
- [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
- [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).
# 0.8.4 (2024-03-17)
- Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).
# 0.8.3 (2024-03-17)
- Re-added unintentionally removed re-exported liballoc macros (#292).
# 0.8.2 (2024-03-17)
- Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).
## 0.8.1 (2024-02-21)
- Fixed clippy warnings (#280)
## 0.8.0 (2024-02-14)
- Implemented the `PartialMmr` data structure (#195).
- Implemented RPX hash function (#201).
- Added `FeltRng` and `RpoRandomCoin` (#237).
- Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
- Added `inner_nodes()` method to `PartialMmr` (#238).
- Improved `PartialMmr::apply_delta()` (#242).
- Refactored `SimpleSmt` struct (#245).
- Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
- Updated Winterfell dependency to v0.8 (#275).
## 0.7.1 (2023-10-10)
- Fixed RPO Falcon signature build on Windows.
## 0.7.0 (2023-10-05)
- Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
- Implemented clearing of nodes in `TieredSmt` (#173).
- Added ability to generate inclusion proofs for `TieredSmt` (#174).
- Implemented Falcon DSA (#179).
- Added conditional `serde`` support for various structs (#180).
- Implemented benchmarking for `TieredSmt` (#182).
- Added more leaf traversal methods for `MerkleStore` (#185).
- Added SVE acceleration for RPO hash function (#189).
## 0.6.0 (2023-06-25)
- [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
- [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
- Added initial implementation of `PartialMerkleTree` (#156).
## 0.5.0 (2023-05-26)
- Implemented `TieredSmt` (#152, #153).
- Implemented ability to extract a subset of a `MerkleStore` (#151).
- Cleaned up `SimpleSmt` interface (#149).
- Decoupled hashing and padding of peaks in `Mmr` (#148).
- Added `inner_nodes()` to `MerkleStore` (#146).
## 0.4.0 (2023-04-21)
- Exported `MmrProof` from the crate (#137).
- Allowed merging of leaves in `MerkleStore` (#138).
- [BREAKING] Refactored how existing data structures are added to `MerkleStore` (#139).
## 0.3.0 (2023-04-08)
- Added `depth` parameter to SMT constructors in `MerkleStore` (#115).
- Optimized MMR peak hashing for Miden VM (#120).
- Added `get_leaf_depth` method to `MerkleStore` (#119).
- Added inner node iterators to `MerkleTree`, `SimpleSmt`, and `Mmr` (#117, #118, #121).
## 0.2.0 (2023-03-24)
- Implemented `Mmr` and related structs (#67).
- Implemented `MerkleStore` (#93, #94, #95, #107 #112).
- Added benchmarks for `MerkleStore` vs. other structs (#97).
- Added Merkle path containers (#99).
- Fixed depth handling in `MerklePathSet` (#110).
- Updated Winterfell dependency to v0.6.
## 0.1.4 (2023-02-22)
- Re-export winter-crypto Hasher, Digest & ElementHasher (#72)
## 0.1.3 (2023-02-20)
- Updated Winterfell dependency to v0.5.1 (#68)
## 0.1.2 (2023-02-17)
- Fixed `Rpo256::hash` pad that was panicking on input (#44)
- Added `MerklePath` wrapper to encapsulate Merkle opening verification and root computation (#53)
- Added `NodeIndex` Merkle wrapper to encapsulate Merkle tree traversal and mappings (#54)
## 0.1.1 (2023-02-06)
- Introduced `merge_in_domain` for the RPO hash function, to allow using a specified domain value in the second capacity register when hashing two digests together.
- Added a simple sparse Merkle tree implementation.
- Added re-exports of Winterfell RandomCoin and RandomCoinError.
## 0.1.0 (2022-12-02)
- Initial release on crates.io containing the cryptographic primitives used in Miden VM and the Miden Rollup.
- Hash module with the BLAKE3 and Rescue Prime Optimized hash functions.
- BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output.
- RPO is implemented with 256-bit output.
- Merkle module, with a set of data structures related to Merkle trees, implemented using the RPO hash function.

View file

@ -1,108 +0,0 @@
# Contributing to Crypto
#### First off, thanks for taking the time to contribute!
We want to make contributing to this project as easy and transparent as possible, whether it's:
- Reporting a [bug](https://github.com/0xPolygonMiden/crypto/issues/new)
- Taking part in [discussions](https://github.com/0xPolygonMiden/crypto/discussions)
- Submitting a [fix](https://github.com/0xPolygonMiden/crypto/pulls)
- Proposing new [features](https://github.com/0xPolygonMiden/crypto/issues/new)
&nbsp;
## Flow
We are using [Github Flow](https://docs.github.com/en/get-started/quickstart/github-flow), so all code changes happen through pull requests from a [forked repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo).
### Branching
- The current active branch is `next`. Every branch with a fix/feature must be forked from `next`.
- The branch name should contain a short issue/feature description separated with hyphens [(kebab-case)](https://en.wikipedia.org/wiki/Letter_case#Kebab_case).
For example, if the issue title is `Fix functionality X in component Y` then the branch name will be something like: `fix-x-in-y`.
- New branch should be rebased from `next` before submitting a PR in case there have been changes to avoid merge commits.
i.e. this branches state:
```
A---B---C fix-x-in-y
/
D---E---F---G next
| |
(F, G) changes happened after `fix-x-in-y` forked
```
should become this after rebase:
```
A'--B'--C' fix-x-in-y
/
D---E---F---G next
```
More about rebase [here](https://git-scm.com/docs/git-rebase) and [here](https://www.atlassian.com/git/tutorials/rewriting-history/git-rebase#:~:text=What%20is%20git%20rebase%3F,of%20a%20feature%20branching%20workflow.)
### Commit messages
- Commit messages should be written in a short, descriptive manner and be prefixed with tags for the change type and scope (if possible) according to the [semantic commit](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) scheme.
For example, a new change to the AIR crate might have the following message: `feat(air): add constraints for the decoder`
- Also squash commits to logically separated, distinguishable stages to keep git log clean:
```
7hgf8978g9... Added A to X \
\ (squash)
gh354354gh... oops, typo --- * ---------> 9fh1f51gh7... feat(X): add A && B
/
85493g2458... Added B to X /
789fdfffdf... Fixed D in Y \
\ (squash)
787g8fgf78... blah blah --- * ---------> 4070df6f00... fix(Y): fixed D && C
/
9080gf6567... Fixed C in Y /
```
### Code Style and Documentation
- For documentation in the codebase, we follow the [rustdoc](https://doc.rust-lang.org/rust-by-example/meta/doc.html) convention with no more than 100 characters per line.
- For code sections, we use code separators like the following to a width of 100 characters::
```
// CODE SECTION HEADER
// ================================================================================
```
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in CI pipeline. Anyways it's preferable to run linting locally before push:
```
cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
```
### Versioning
We use [semver](https://semver.org/) naming convention.
&nbsp;
## Pre-PR checklist
1. Repo forked and branch created from `next` according to the naming convention.
2. Commit messages and code style follow conventions.
3. Tests added for new functionality.
4. Documentation/comments updated for all changes according to our documentation convention.
5. Clippy and Rustfmt linting passed.
6. New branch rebased from `next`.
&nbsp;
## Write bug reports with detail, background, and sample code
**Great Bug Reports** tend to have:
- A quick summary and/or background
- Steps to reproduce
- What you expected would happen
- What actually happens
- Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
&nbsp;
## Any contributions you make will be under the MIT Software License
In short, when you submit code changes, your submissions are understood to be under the same [MIT License](http://choosealicense.com/licenses/mit/) that covers the project. Feel free to contact the maintainers if that's a concern.

1297
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,93 +0,0 @@
[package]
name = "miden-crypto"
version = "0.14.0"
description = "Miden Cryptographic primitives"
authors = ["miden contributors"]
readme = "README.md"
license = "MIT"
repository = "https://github.com/0xPolygonMiden/crypto"
documentation = "https://docs.rs/miden-crypto/0.14.0"
categories = ["cryptography", "no-std"]
keywords = ["miden", "crypto", "hash", "merkle"]
edition = "2021"
rust-version = "1.84"
[[bin]]
name = "miden-crypto"
path = "src/main.rs"
bench = false
doctest = false
required-features = ["executable"]
[[bench]]
name = "hash"
harness = false
[[bench]]
name = "smt"
harness = false
[[bench]]
name = "smt-subtree"
harness = false
required-features = ["internal"]
[[bench]]
name = "merkle"
harness = false
[[bench]]
name = "smt-with-entries"
harness = false
[[bench]]
name = "store"
harness = false
[features]
concurrent = ["dep:rayon", "hashbrown?/rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = ["dep:hashbrown"]
internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
"blake3/std",
"dep:cc",
"rand/std",
"rand/std_rng",
"winter-crypto/std",
"winter-math/std",
"winter-utils/std",
]
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
hashbrown = { version = "0.15", optional = true, features = ["serde"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }
rand_core = { version = "0.6", default-features = false }
rand-utils = { version = "0.12", package = "winter-rand-utils", optional = true }
rayon = { version = "1.10", optional = true }
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
sha3 = { version = "0.10", default-features = false }
thiserror = { version = "2.0", default-features = false }
winter-crypto = { version = "0.12", default-features = false }
winter-math = { version = "0.12", default-features = false }
winter-utils = { version = "0.12", default-features = false }
[dev-dependencies]
assert_matches = { version = "1.5", default-features = false }
criterion = { version = "0.5", features = ["html_reports"] }
getrandom = { version = "0.2", features = ["js"] }
hex = { version = "0.4", default-features = false, features = ["alloc"] }
proptest = "1.6"
rand_chacha = { version = "0.3", default-features = false }
rand-utils = { version = "0.12", package = "winter-rand-utils" }
seq-macro = { version = "0.3" }
[build-dependencies]
cc = { version = "1.2", optional = true, features = ["parallel"] }
glob = "0.3"

21
LICENSE
View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2025 Polygon Miden
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,93 +0,0 @@
.DEFAULT_GOAL := help
.PHONY: help
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
# -- variables --------------------------------------------------------------------------------------
WARNINGS=RUSTDOCFLAGS="-D warnings"
DEBUG_OVERFLOW_INFO=RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2"
# -- linting --------------------------------------------------------------------------------------
.PHONY: clippy
clippy: ## Run Clippy with configs
$(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features
.PHONY: fix
fix: ## Run Fix with configs
cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features
.PHONY: format
format: ## Run Format using nightly toolchain
cargo +nightly fmt --all
.PHONY: format-check
format-check: ## Run Format using nightly toolchain but only in check mode
cargo +nightly fmt --all --check
.PHONY: lint
lint: format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting)
# --- docs ----------------------------------------------------------------------------------------
.PHONY: doc
doc: ## Generate and check documentation
$(WARNINGS) cargo doc --all-features --keep-going --release
# --- testing -------------------------------------------------------------------------------------
.PHONY: test-default
test-default: ## Run tests with default features
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features
.PHONY: test-smt-hashmaps
test-smt-hashmaps: ## Run tests with `smt_hashmaps` feature enabled
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --features smt_hashmaps
.PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (std)
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --no-default-features
.PHONY: test
test: test-default test-smt-hashmaps test-no-std ## Run all tests
# --- checking ------------------------------------------------------------------------------------
.PHONY: check
check: ## Check all targets and features for errors without code generation
cargo check --all-targets --all-features
# --- building ------------------------------------------------------------------------------------
.PHONY: build
build: ## Build with default features enabled
cargo build --release
.PHONY: build-no-std
build-no-std: ## Build without the standard library
cargo build --release --no-default-features --target wasm32-unknown-unknown
.PHONY: build-avx2
build-avx2: ## Build with avx2 support
RUSTFLAGS="-C target-feature=+avx2" cargo build --release
.PHONY: build-sve
build-sve: ## Build with sve support
RUSTFLAGS="-C target-feature=+sve" cargo build --release
# --- benchmarking --------------------------------------------------------------------------------
.PHONY: bench
bench: ## Run crypto benchmarks
cargo bench --features concurrent
.PHONY: bench-smt-concurrent
bench-smt-concurrent: ## Run SMT benchmarks with concurrent feature
cargo run --release --features concurrent,executable -- --size 1000000

110
README.md
View file

@ -1,110 +0,0 @@
# Miden Crypto
[![LICENSE](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
[![test](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
[![build](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml)
[![RUST_VERSION](https://img.shields.io/badge/rustc-1.84+-lightgray.svg)](https://www.rust-lang.org/tools/install)
[![CRATE](https://img.shields.io/crates/v/miden-crypto)](https://crates.io/crates/miden-crypto)
This crate contains cryptographic primitives used in Polygon Miden.
## Hash
[Hash module](./src/hash) provides a set of cryptographic hash functions which are used by the Miden VM and the Miden rollup. Currently, these functions are:
- [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
- [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
- [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
## Merkle
[Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:
- `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
- `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
- `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
- `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
- `PartialMmr`: a partial view of a Merkle mountain range structure.
- `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
- `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
## Signatures
[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
- `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the _hash-to-point_ algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.
For the above signatures, key generation, signing, and signature verification are available for both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.
## Pseudo-Random Element Generator
[Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:
- `FeltRng`: a trait for generating random field elements and random 4 field elements.
- `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPO hash function.
- `RpxRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPX hash function.
## Make commands
We use `make` to automate building, testing, and other processes. In most cases, `make` commands are wrappers around `cargo` commands with specific arguments. You can view the list of available commands in the [Makefile](Makefile), or run the following command:
```shell
make
```
## Crate features
This crate can be compiled with the following features:
- `concurrent`- enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs.
- `std` - enabled by default and relies on the Rust standard library.
- `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
- `smt_hashmaps` - uses hashbrown hashmaps in SMT implementation which significantly improves performance of SMT updating. Keys ordering in SMT iterators is not guarantied when this feature is enabled.
All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
To compile with `no_std`, disable default features via `--no-default-features` flag or using the following command:
```shell
make build-no-std
```
### AVX2 acceleration
On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
```shell
make build-avx2
```
### SVE acceleration
On platforms with [SVE](<https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)>) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
```shell
make build-sve
```
## Testing
The best way to test the library is using our [Makefile](Makefile), this will enable you to use our pre-defined optimized testing commands:
```shell
make test
```
For example, some of the functions are heavy and might take a while for the tests to complete if using simply `cargo test`. In order to test in release and optimized mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.
We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile](Makefile)):
```shell
RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
```
## License
This project is [MIT licensed](./LICENSE).

View file

@ -1,55 +0,0 @@
#include <stddef.h>
#include <arm_sve.h>
#include "library.h"
#include "rpo_hash_128bit.h"
#include "rpo_hash_256bit.h"
// The STATE_WIDTH of RPO hash is 12x u64 elements.
// The current generation of SVE-enabled processors - Neoverse V1
// (e.g. AWS Graviton3) have 256-bit vector registers (4x u64)
// This allows us to split the state into 3 vectors of 4 elements
// and process all 3 independent of each other.
// We see the biggest performance gains by leveraging both
// vector and scalar operations on parts of the state array.
// Due to high latency of vector operations, the processor is able
// to reorder and pipeline scalar instructions while we wait for
// vector results. This effectively gives us some 'free' scalar
// operations and masks vector latency.
//
// This also means that we can fully saturate all four arithmetic
// units of the processor (2x scalar, 2x SIMD)
//
// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS
// GAIN WIDER REGISTERS. It's quite possible that with 8x u64
// vectors processing 2 partially filled vectors might
// be easier and faster than dealing with scalar operations
// on the remainder of the array.
//
// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back
// to the regular, already highly-optimized scalar version
// if the conditions are not met.
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector
if (vl == 2) {
return add_constants_and_apply_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_sbox_256(state, constants);
} else {
return false;
}
}
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector
if (vl == 2) {
return add_constants_and_apply_inv_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_inv_sbox_256(state, constants);
} else {
return false;
}
}

View file

@ -1,12 +0,0 @@
#ifndef CRYPTO_LIBRARY_H
#define CRYPTO_LIBRARY_H
#include <stdint.h>
#include <stdbool.h>
#define STATE_WIDTH 12
bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
#endif //CRYPTO_LIBRARY_H

View file

@ -1,318 +0,0 @@
#ifndef RPO_SVE_RPO_HASH_128_H
#define RPO_SVE_RPO_HASH_128_H
#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#define STATE_WIDTH 12
#define COPY_128(NAME, VIN1, VIN2, VIN3, VIN4, SIN) \
svuint64_t NAME ## _1 = VIN1; \
svuint64_t NAME ## _2 = VIN2; \
svuint64_t NAME ## _3 = VIN3; \
svuint64_t NAME ## _4 = VIN4; \
uint64_t NAME ## _tail[4]; \
memcpy(NAME ## _tail, SIN, 4 * sizeof(uint64_t))
#define MULTIPLY_128(PRED, DEST, OP) \
mul_128(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, &DEST ## _3, &OP ## _3, &DEST ## _4, &OP ## _4, DEST ## _tail, OP ## _tail)
#define SQUARE_128(PRED, NAME) \
sq_128(PRED, &NAME ## _1, &NAME ## _2, &NAME ## _3, &NAME ## _4, NAME ## _tail)
#define SQUARE_DEST_128(PRED, DEST, SRC) \
COPY_128(DEST, SRC ## _1, SRC ## _2, SRC ## _3, SRC ## _4, SRC ## _tail); \
SQUARE_128(PRED, DEST);
#define POW_ACC_128(PRED, NAME, CNT, TAIL) \
for (size_t i = 0; i < CNT; i++) { \
SQUARE_128(PRED, NAME); \
} \
MULTIPLY_128(PRED, NAME, TAIL);
#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \
COPY_128(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3, HEAD ## _4, HEAD ## _tail); \
POW_ACC_128(PRED, DEST, CNT, TAIL)
extern inline void add_constants_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *const1,
svuint64_t *state2,
svuint64_t *const2,
svuint64_t *state3,
svuint64_t *const3,
svuint64_t *state4,
svuint64_t *const4,
uint64_t *state_tail,
uint64_t *const_tail
) {
uint64_t Ms = 0xFFFFFFFF00000001ull;
svuint64_t Mv = svindex_u64(Ms, 0);
uint64_t p_1 = Ms - const_tail[0];
uint64_t p_2 = Ms - const_tail[1];
uint64_t p_3 = Ms - const_tail[2];
uint64_t p_4 = Ms - const_tail[3];
uint64_t x_1, x_2, x_3, x_4;
uint32_t adj_1 = -__builtin_sub_overflow(state_tail[0], p_1, &x_1);
uint32_t adj_2 = -__builtin_sub_overflow(state_tail[1], p_2, &x_2);
uint32_t adj_3 = -__builtin_sub_overflow(state_tail[2], p_3, &x_3);
uint32_t adj_4 = -__builtin_sub_overflow(state_tail[3], p_4, &x_4);
state_tail[0] = x_1 - (uint64_t)adj_1;
state_tail[1] = x_2 - (uint64_t)adj_2;
state_tail[2] = x_3 - (uint64_t)adj_3;
state_tail[3] = x_4 - (uint64_t)adj_4;
svuint64_t p1 = svsub_x(pg, Mv, *const1);
svuint64_t p2 = svsub_x(pg, Mv, *const2);
svuint64_t p3 = svsub_x(pg, Mv, *const3);
svuint64_t p4 = svsub_x(pg, Mv, *const4);
svuint64_t x1 = svsub_x(pg, *state1, p1);
svuint64_t x2 = svsub_x(pg, *state2, p2);
svuint64_t x3 = svsub_x(pg, *state3, p3);
svuint64_t x4 = svsub_x(pg, *state4, p4);
svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
svbool_t pt3 = svcmplt_u64(pg, *state3, p3);
svbool_t pt4 = svcmplt_u64(pg, *state4, p4);
*state1 = svsub_m(pt1, x1, (uint32_t)-1);
*state2 = svsub_m(pt2, x2, (uint32_t)-1);
*state3 = svsub_m(pt3, x3, (uint32_t)-1);
*state4 = svsub_m(pt4, x4, (uint32_t)-1);
}
extern inline void mul_128(
svbool_t pg,
svuint64_t *r1,
const svuint64_t *op1,
svuint64_t *r2,
const svuint64_t *op2,
svuint64_t *r3,
const svuint64_t *op3,
svuint64_t *r4,
const svuint64_t *op4,
uint64_t *r_tail,
const uint64_t *op_tail
) {
__uint128_t x_1 = r_tail[0];
__uint128_t x_2 = r_tail[1];
__uint128_t x_3 = r_tail[2];
__uint128_t x_4 = r_tail[3];
x_1 *= (__uint128_t) op_tail[0];
x_2 *= (__uint128_t) op_tail[1];
x_3 *= (__uint128_t) op_tail[2];
x_4 *= (__uint128_t) op_tail[3];
uint64_t x0_1 = x_1;
uint64_t x0_2 = x_2;
uint64_t x0_3 = x_3;
uint64_t x0_4 = x_4;
svuint64_t l1 = svmul_x(pg, *r1, *op1);
svuint64_t l2 = svmul_x(pg, *r2, *op2);
svuint64_t l3 = svmul_x(pg, *r3, *op3);
svuint64_t l4 = svmul_x(pg, *r4, *op4);
uint64_t x1_1 = (x_1 >> 64);
uint64_t x1_2 = (x_2 >> 64);
uint64_t x1_3 = (x_3 >> 64);
uint64_t x1_4 = (x_4 >> 64);
uint64_t a_1, a_2, a_3, a_4;
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);
svuint64_t ls1 = svlsl_x(pg, l1, 32);
svuint64_t ls2 = svlsl_x(pg, l2, 32);
svuint64_t ls3 = svlsl_x(pg, l3, 32);
svuint64_t ls4 = svlsl_x(pg, l4, 32);
svuint64_t a1 = svadd_x(pg, l1, ls1);
svuint64_t a2 = svadd_x(pg, l2, ls2);
svuint64_t a3 = svadd_x(pg, l3, ls3);
svuint64_t a4 = svadd_x(pg, l4, ls4);
svbool_t e1 = svcmplt(pg, a1, l1);
svbool_t e2 = svcmplt(pg, a2, l2);
svbool_t e3 = svcmplt(pg, a3, l3);
svbool_t e4 = svcmplt(pg, a4, l4);
svuint64_t as1 = svlsr_x(pg, a1, 32);
svuint64_t as2 = svlsr_x(pg, a2, 32);
svuint64_t as3 = svlsr_x(pg, a3, 32);
svuint64_t as4 = svlsr_x(pg, a4, 32);
svuint64_t b1 = svsub_x(pg, a1, as1);
svuint64_t b2 = svsub_x(pg, a2, as2);
svuint64_t b3 = svsub_x(pg, a3, as3);
svuint64_t b4 = svsub_x(pg, a4, as4);
b1 = svsub_m(e1, b1, 1);
b2 = svsub_m(e2, b2, 1);
b3 = svsub_m(e3, b3, 1);
b4 = svsub_m(e4, b4, 1);
uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;
uint64_t r_1, r_2, r_3, r_4;
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);
svuint64_t h1 = svmulh_x(pg, *r1, *op1);
svuint64_t h2 = svmulh_x(pg, *r2, *op2);
svuint64_t h3 = svmulh_x(pg, *r3, *op3);
svuint64_t h4 = svmulh_x(pg, *r4, *op4);
svuint64_t tr1 = svsub_x(pg, h1, b1);
svuint64_t tr2 = svsub_x(pg, h2, b2);
svuint64_t tr3 = svsub_x(pg, h3, b3);
svuint64_t tr4 = svsub_x(pg, h4, b4);
svbool_t c1 = svcmplt_u64(pg, h1, b1);
svbool_t c2 = svcmplt_u64(pg, h2, b2);
svbool_t c3 = svcmplt_u64(pg, h3, b3);
svbool_t c4 = svcmplt_u64(pg, h4, b4);
*r1 = svsub_m(c1, tr1, (uint32_t) -1);
*r2 = svsub_m(c2, tr2, (uint32_t) -1);
*r3 = svsub_m(c3, tr3, (uint32_t) -1);
*r4 = svsub_m(c4, tr4, (uint32_t) -1);
uint32_t minus1_1 = 0 - c_1;
uint32_t minus1_2 = 0 - c_2;
uint32_t minus1_3 = 0 - c_3;
uint32_t minus1_4 = 0 - c_4;
r_tail[0] = r_1 - (uint64_t)minus1_1;
r_tail[1] = r_2 - (uint64_t)minus1_2;
r_tail[2] = r_3 - (uint64_t)minus1_3;
r_tail[3] = r_4 - (uint64_t)minus1_4;
}
extern inline void sq_128(svbool_t pg, svuint64_t *a, svuint64_t *b, svuint64_t *c, svuint64_t *d, uint64_t *e) {
mul_128(pg, a, a, b, b, c, c, d, d, e, e);
}
extern inline void apply_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
COPY_128(x, *state1, *state2, *state3, *state4, state_tail); // copy input to x
SQUARE_128(pg, x); // x contains input^2
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^3
SQUARE_128(pg, x); // x contains input^4
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^7
}
extern inline void apply_inv_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
// base^10
COPY_128(t1, *state1, *state2, *state3, *state4, state_tail);
SQUARE_128(pg, t1);
// base^100
SQUARE_DEST_128(pg, t2, t1);
// base^100100
POW_ACC_DEST(pg, t3, 3, t2, t2);
// base^100100100100
POW_ACC_DEST(pg, t4, 6, t3, t3);
// compute base^100100100100100100100100
POW_ACC_DEST(pg, t5, 12, t4, t4);
// compute base^100100100100100100100100100100
POW_ACC_DEST(pg, t6, 6, t5, t3);
// compute base^1001001001001001001001001001000100100100100100100100100100100
POW_ACC_DEST(pg, t7, 31, t6, t6);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t6);
SQUARE_128(pg, t7);
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t1);
MULTIPLY_128(pg, t7, t2);
mul_128(pg, state1, &t7_1, state2, &t7_2, state3, &t7_3, state4, &t7_4, state_tail, t7_tail);
}
bool add_constants_and_apply_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();
svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);
svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);
add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);
svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);
return true;
}
bool add_constants_and_apply_inv_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();
svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);
svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);
add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_inv_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);
svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);
return true;
}
#endif //RPO_SVE_RPO_HASH_128_H

View file

@ -1,261 +0,0 @@
#ifndef RPO_SVE_RPO_HASH_256_H
#define RPO_SVE_RPO_HASH_256_H
#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#define STATE_WIDTH 12
#define COPY_256(NAME, VIN1, VIN2, SIN3) \
svuint64_t NAME ## _1 = VIN1; \
svuint64_t NAME ## _2 = VIN2; \
uint64_t NAME ## _3[4]; \
memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t))
#define MULTIPLY_256(PRED, DEST, OP) \
mul_256(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3)
#define SQUARE_256(PRED, NAME) \
sq_256(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3)
#define SQUARE_DEST_256(PRED, DEST, SRC) \
COPY_256(DEST, SRC ## _1, SRC ## _2, SRC ## _3); \
SQUARE_256(PRED, DEST);
#define POW_ACC(PRED, NAME, CNT, TAIL) \
for (size_t i = 0; i < CNT; i++) { \
SQUARE_256(PRED, NAME); \
} \
MULTIPLY_256(PRED, NAME, TAIL);
#define POW_ACC_DEST_256(PRED, DEST, CNT, HEAD, TAIL) \
COPY_256(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \
POW_ACC(PRED, DEST, CNT, TAIL)
extern inline void add_constants_256(
svbool_t pg,
svuint64_t *state1,
svuint64_t *const1,
svuint64_t *state2,
svuint64_t *const2,
uint64_t *state3,
uint64_t *const3
) {
uint64_t Ms = 0xFFFFFFFF00000001ull;
svuint64_t Mv = svindex_u64(Ms, 0);
uint64_t p_1 = Ms - const3[0];
uint64_t p_2 = Ms - const3[1];
uint64_t p_3 = Ms - const3[2];
uint64_t p_4 = Ms - const3[3];
uint64_t x_1, x_2, x_3, x_4;
uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1);
uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2);
uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3);
uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4);
state3[0] = x_1 - (uint64_t)adj_1;
state3[1] = x_2 - (uint64_t)adj_2;
state3[2] = x_3 - (uint64_t)adj_3;
state3[3] = x_4 - (uint64_t)adj_4;
svuint64_t p1 = svsub_x(pg, Mv, *const1);
svuint64_t p2 = svsub_x(pg, Mv, *const2);
svuint64_t x1 = svsub_x(pg, *state1, p1);
svuint64_t x2 = svsub_x(pg, *state2, p2);
svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
*state1 = svsub_m(pt1, x1, (uint32_t)-1);
*state2 = svsub_m(pt2, x2, (uint32_t)-1);
}
extern inline void mul_256(
svbool_t pg,
svuint64_t *r1,
const svuint64_t *op1,
svuint64_t *r2,
const svuint64_t *op2,
uint64_t *r3,
const uint64_t *op3
) {
__uint128_t x_1 = r3[0];
__uint128_t x_2 = r3[1];
__uint128_t x_3 = r3[2];
__uint128_t x_4 = r3[3];
x_1 *= (__uint128_t) op3[0];
x_2 *= (__uint128_t) op3[1];
x_3 *= (__uint128_t) op3[2];
x_4 *= (__uint128_t) op3[3];
uint64_t x0_1 = x_1;
uint64_t x0_2 = x_2;
uint64_t x0_3 = x_3;
uint64_t x0_4 = x_4;
svuint64_t l1 = svmul_x(pg, *r1, *op1);
svuint64_t l2 = svmul_x(pg, *r2, *op2);
uint64_t x1_1 = (x_1 >> 64);
uint64_t x1_2 = (x_2 >> 64);
uint64_t x1_3 = (x_3 >> 64);
uint64_t x1_4 = (x_4 >> 64);
uint64_t a_1, a_2, a_3, a_4;
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);
svuint64_t ls1 = svlsl_x(pg, l1, 32);
svuint64_t ls2 = svlsl_x(pg, l2, 32);
svuint64_t a1 = svadd_x(pg, l1, ls1);
svuint64_t a2 = svadd_x(pg, l2, ls2);
svbool_t e1 = svcmplt(pg, a1, l1);
svbool_t e2 = svcmplt(pg, a2, l2);
svuint64_t as1 = svlsr_x(pg, a1, 32);
svuint64_t as2 = svlsr_x(pg, a2, 32);
svuint64_t b1 = svsub_x(pg, a1, as1);
svuint64_t b2 = svsub_x(pg, a2, as2);
b1 = svsub_m(e1, b1, 1);
b2 = svsub_m(e2, b2, 1);
uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;
uint64_t r_1, r_2, r_3, r_4;
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);
svuint64_t h1 = svmulh_x(pg, *r1, *op1);
svuint64_t h2 = svmulh_x(pg, *r2, *op2);
svuint64_t tr1 = svsub_x(pg, h1, b1);
svuint64_t tr2 = svsub_x(pg, h2, b2);
svbool_t c1 = svcmplt_u64(pg, h1, b1);
svbool_t c2 = svcmplt_u64(pg, h2, b2);
*r1 = svsub_m(c1, tr1, (uint32_t) -1);
*r2 = svsub_m(c2, tr2, (uint32_t) -1);
uint32_t minus1_1 = 0 - c_1;
uint32_t minus1_2 = 0 - c_2;
uint32_t minus1_3 = 0 - c_3;
uint32_t minus1_4 = 0 - c_4;
r3[0] = r_1 - (uint64_t)minus1_1;
r3[1] = r_2 - (uint64_t)minus1_2;
r3[2] = r_3 - (uint64_t)minus1_3;
r3[3] = r_4 - (uint64_t)minus1_4;
}
extern inline void sq_256(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) {
mul_256(pg, a, a, b, b, c, c);
}
extern inline void apply_sbox_256(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
uint64_t *state3
) {
COPY_256(x, *state1, *state2, state3); // copy input to x
SQUARE_256(pg, x); // x contains input^2
mul_256(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^3
SQUARE_256(pg, x); // x contains input^4
mul_256(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^7
}
extern inline void apply_inv_sbox_256(
svbool_t pg,
svuint64_t *state_1,
svuint64_t *state_2,
uint64_t *state_3
) {
// base^10
COPY_256(t1, *state_1, *state_2, state_3);
SQUARE_256(pg, t1);
// base^100
SQUARE_DEST_256(pg, t2, t1);
// base^100100
POW_ACC_DEST_256(pg, t3, 3, t2, t2);
// base^100100100100
POW_ACC_DEST_256(pg, t4, 6, t3, t3);
// compute base^100100100100100100100100
POW_ACC_DEST_256(pg, t5, 12, t4, t4);
// compute base^100100100100100100100100100100
POW_ACC_DEST_256(pg, t6, 6, t5, t3);
// compute base^1001001001001001001001001001000100100100100100100100100100100
POW_ACC_DEST_256(pg, t7, 31, t6, t6);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
SQUARE_256(pg, t7);
MULTIPLY_256(pg, t7, t6);
SQUARE_256(pg, t7);
SQUARE_256(pg, t7);
MULTIPLY_256(pg, t7, t1);
MULTIPLY_256(pg, t7, t2);
mul_256(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3);
}
bool add_constants_and_apply_sbox_256(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 4; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();
svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
add_constants_256(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
apply_sbox_256(ptrue, &state1, &state2, state + 8);
svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
return true;
}
bool add_constants_and_apply_inv_sbox_256(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 4; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();
svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
add_constants_256(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
apply_inv_sbox_256(ptrue, &state1, &state2, state + 8);
svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
return true;
}
#endif //RPO_SVE_RPO_HASH_256_H

View file

@ -1,58 +0,0 @@
# Miden VM Hash Functions
In the Miden VM, we make use of different hash functions. Some of these are "traditional" hash functions, like `BLAKE3`, which are optimized for out-of-STARK performance, while others are algebraic hash functions, like `Rescue Prime`, and are more optimized for a better performance inside the STARK. In what follows, we benchmark several such hash functions and compare against other constructions that are used by other proving systems. More precisely, we benchmark:
* **BLAKE3** as specified [here](https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf) and implemented [here](https://github.com/BLAKE3-team/BLAKE3) (with a wrapper exposed via this crate).
* **SHA3** as specified [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/sha/mod.rs).
* **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
* **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
* **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
* **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate.
## Comparison and Instructions
### Comparison
We benchmark the above hash functions using two scenarios. The first is a 2-to-1 $(a,b)\mapsto h(a,b)$ hashing where both $a$, $b$ and $h(a,b)$ are the digests corresponding to each of the hash functions.
The second scenario is that of sequential hashing where we take a sequence of length $100$ field elements and hash these to produce a single digest. The digests are $4$ field elements in a prime field with modulus $2^{64} - 2^{32} + 1$ (i.e., 32 bytes) for Poseidon, Rescue Prime and RPO, and an array `[u8; 32]` for SHA3 and BLAKE3.
#### Scenario 1: 2-to-1 hashing `h(a,b)`
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro | 76 ns | 245 ns | 1.5 µs | 9.1 µs | 5.2 µs | 2.7 µs |
| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs |
| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs |
| Amazon Graviton 4 | 96 ns | | | | 5.1 µs | 2.8 µs |
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | |
| AMD EPYC 9R14 | 83 ns | | | | 4.3 µs | 2.4 µs |
| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
#### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro | 1.0 µs | 1.5 µs | 19.4 µs | 118 µs | 69 µs | 35 µs |
| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs |
| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs |
| Amazon Graviton 4 | 1.2 µs | | | | 67 µs | 36 µs |
| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | |
| AMD EPYC 9R14 | 0.9 µs | | | | 56 µs | 32 µs |
| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs |
| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | |
Notes:
- On Graviton 3 and 4, RPO256 and RPX256 are run with SVE acceleration enabled.
- On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
```
cargo bench hash
```
To run the benchmarks for Rescue Prime, Poseidon and SHA3, clone the following [repository](https://github.com/Dominik1999/winterfell.git) as above, then checkout the `hash-functions-benches` branch, and from the root directory of the repo run the following:
```
cargo bench hash
```

View file

@ -1,161 +0,0 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{
hash::{
blake::Blake3_256,
rpo::{Rpo256, RpoDigest},
rpx::{Rpx256, RpxDigest},
},
Felt,
};
use rand_utils::rand_value;
use winter_crypto::Hasher;
fn rpo256_2to1(c: &mut Criterion) {
let v: [RpoDigest; 2] = [Rpo256::hash(&[1_u8]), Rpo256::hash(&[2_u8])];
c.bench_function("RPO256 2-to-1 hashing (cached)", |bench| {
bench.iter(|| Rpo256::merge(black_box(&v)))
});
c.bench_function("RPO256 2-to-1 hashing (random)", |bench| {
bench.iter_batched(
|| {
[
Rpo256::hash(&rand_value::<u64>().to_le_bytes()),
Rpo256::hash(&rand_value::<u64>().to_le_bytes()),
]
},
|state| Rpo256::merge(&state),
BatchSize::SmallInput,
)
});
}
fn rpo256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
c.bench_function("RPO256 sequential hashing (cached)", |bench| {
bench.iter(|| Rpo256::hash_elements(black_box(&v)))
});
c.bench_function("RPO256 sequential hashing (random)", |bench| {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
v
},
|state| Rpo256::hash_elements(&state),
BatchSize::SmallInput,
)
});
}
fn rpx256_2to1(c: &mut Criterion) {
let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
bench.iter(|| Rpx256::merge(black_box(&v)))
});
c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
bench.iter_batched(
|| {
[
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
]
},
|state| Rpx256::merge(&state),
BatchSize::SmallInput,
)
});
}
fn rpx256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
c.bench_function("RPX256 sequential hashing (cached)", |bench| {
bench.iter(|| Rpx256::hash_elements(black_box(&v)))
});
c.bench_function("RPX256 sequential hashing (random)", |bench| {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
v
},
|state| Rpx256::hash_elements(&state),
BatchSize::SmallInput,
)
});
}
fn blake3_2to1(c: &mut Criterion) {
let v: [<Blake3_256 as Hasher>::Digest; 2] =
[Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
c.bench_function("Blake3 2-to-1 hashing (cached)", |bench| {
bench.iter(|| Blake3_256::merge(black_box(&v)))
});
c.bench_function("Blake3 2-to-1 hashing (random)", |bench| {
bench.iter_batched(
|| {
[
Blake3_256::hash(&rand_value::<u64>().to_le_bytes()),
Blake3_256::hash(&rand_value::<u64>().to_le_bytes()),
]
},
|state| Blake3_256::merge(&state),
BatchSize::SmallInput,
)
});
}
fn blake3_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
c.bench_function("Blake3 sequential hashing (cached)", |bench| {
bench.iter(|| Blake3_256::hash_elements(black_box(&v)))
});
c.bench_function("Blake3 sequential hashing (random)", |bench| {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
v
},
|state| Blake3_256::hash_elements(&state),
BatchSize::SmallInput,
)
});
}
criterion_group!(
hash_group,
rpx256_2to1,
rpx256_sequential,
rpo256_2to1,
rpo256_sequential,
blake3_2to1,
blake3_sequential
);
criterion_main!(hash_group);

View file

@ -1,66 +0,0 @@
//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared
//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with
//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building
//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle
//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this
//! benchmark is about four times more efficient than the equivalent benchmark in
//! `benches/smt-subtree.rs`.
use std::{hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE};
use rand_utils::prng_array;
fn balanced_merkle_even(c: &mut Criterion) {
c.bench_function("balanced-merkle-even", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> =
(0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
fn balanced_merkle_rand(c: &mut Criterion) {
let mut seed = [0u8; 32];
c.bench_function("balanced-merkle-rand", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> = (0..256).map(|_| generate_word(&mut seed)).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(20))
.configure_from_args();
targets = balanced_merkle_even, balanced_merkle_rand
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -1,143 +0,0 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{build_subtree_for_bench, NodeIndex, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;
const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];
fn smt_subtree_even(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-even");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|n| {
// A single depth-8 subtree can have a maximum of 255 leaves.
let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
let key = RpoDigest::new([
generate_value(&mut seed),
ONE,
Felt::new(n),
Felt::new(leaf_index),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves.dedup_by_key(|leaf| leaf.col);
leaves
},
|leaves| {
// Benchmarked function.
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
fn smt_subtree_random(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-rand");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let leaf_index: u8 = generate_value(&mut seed);
let key = RpoDigest::new([
ONE,
ONE,
Felt::new(i),
Felt::new(leaf_index as u64),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves.dedup_by_key(|leaf| leaf.col);
leaves
},
|leaves| {
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(40))
.sample_size(60)
.configure_from_args();
targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -1,71 +0,0 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;
// 2^0, 2^4, 2^8, 2^12, 2^16
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];
fn smt_with_entries(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("smt-with-entries");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
prepare_entries(pair_count, &mut seed)
},
|entries| {
// Benchmarked function.
Smt::with_entries(hint::black_box(entries)).unwrap();
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_with_entries_group;
config = Criterion::default()
//.measurement_time(Duration::from_secs(960))
.measurement_time(Duration::from_secs(60))
.sample_size(10)
.configure_from_args();
targets = smt_with_entries
}
criterion_main!(smt_with_entries_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn prepare_entries(pair_count: u64, seed: &mut [u8; 32]) -> Vec<(RpoDigest, [Felt; 4])> {
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let count = pair_count as f64;
let idx = ((i as f64 / count) * (count)) as u64;
let key = RpoDigest::new([generate_value(seed), ONE, Felt::new(i), Felt::new(idx)]);
let value = generate_word(seed);
(key, value)
})
.collect();
entries
}
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -1,77 +0,0 @@
use core::mem::swap;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use miden_crypto::{
merkle::{LeafIndex, SimpleSmt},
Felt, Word,
};
use rand_utils::prng_array;
use seq_macro::seq;
fn smt_rpo(c: &mut Criterion) {
// setup trees
let mut seed = [0u8; 32];
let leaf = generate_word(&mut seed);
seq!(DEPTH in 14..=20 {
let leaves = ((1 << DEPTH) - 1) as u64;
for count in [1, leaves / 2, leaves] {
let entries: Vec<_> = (0..count)
.map(|i| {
let word = generate_word(&mut seed);
(i, word)
})
.collect();
let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();
// benchmark 1
let mut insert = c.benchmark_group("smt update_leaf".to_string());
{
let depth = DEPTH;
let key = count >> 2;
insert.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&(key, leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
});
},
);
}
insert.finish();
// benchmark 2
let mut path = c.benchmark_group("smt get_leaf_path".to_string());
{
let depth = DEPTH;
let key = count >> 2;
path.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&key,
|b, key| {
b.iter(|| {
tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
});
},
);
}
path.finish();
}
});
}
criterion_group!(smt_group, smt_rpo);
criterion_main!(smt_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_word(seed: &mut [u8; 32]) -> Word {
swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -1,487 +0,0 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{
DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt,
SMT_MAX_DEPTH,
},
Felt, Word,
};
use rand_utils::{rand_array, rand_value};
/// Since MerkleTree can only be created when a power-of-two number of elements is used, the sample
/// sizes are limited to that.
static BATCH_SIZES: [usize; 3] = [2usize.pow(4), 2usize.pow(7), 2usize.pow(10)];
/// Generates a random `RpoDigest`.
fn random_rpo_digest() -> RpoDigest {
rand_array::<Felt, 4>().into()
}
/// Generates a random `Word`.
fn random_word() -> Word {
rand_array::<Felt, 4>()
}
/// Generates an index at the specified depth in `0..range`.
fn random_index(range: u64, depth: u8) -> NodeIndex {
let value = rand_value::<u64>() % range;
NodeIndex::new(depth, value).unwrap()
}
/// Benchmarks getting an empty leaf from the SMT and MerkleStore backends.
fn get_empty_leaf_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_empty_leaf_simplesmt");
const DEPTH: u8 = SMT_MAX_DEPTH;
let size = u64::MAX;
// both SMT and the store are pre-populated with empty hashes, accessing these values is what is
// being benchmarked here, so no values are inserted into the backends
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
b.iter_batched(
|| random_index(size, DEPTH),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
b.iter_batched(
|| random_index(size, DEPTH),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
/// Benchmarks getting a leaf on Merkle trees and Merkle stores of varying power-of-two sizes.
fn get_leaf_merkletree(c: &mut Criterion) {
let mut group = c.benchmark_group("get_leaf_merkletree");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
let store = MerkleStore::from(&mtree);
let depth = mtree.depth();
let root = mtree.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|index| black_box(mtree.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks getting a leaf on SMT and Merkle stores of varying power-of-two sizes.
fn get_leaf_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_leaf_simplesmt");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let smt_leaves = leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks getting a node at half of the depth of an empty SMT and an empty Merkle store.
fn get_node_of_empty_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_node_of_empty_simplesmt");
const DEPTH: u8 = SMT_MAX_DEPTH;
// both SMT and the store are pre-populated with the empty hashes, accessing the internal nodes
// of these values is what is being benchmarked here, so no values are inserted into the
// backends.
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = DEPTH / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
/// Benchmarks getting a node at half of the depth of a Merkle tree and Merkle store of varying
/// power-of-two sizes.
fn get_node_merkletree(c: &mut Criterion) {
let mut group = c.benchmark_group("get_node_merkletree");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
let store = MerkleStore::from(&mtree);
let root = mtree.root();
let half_depth = mtree.depth() / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(mtree.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks getting a node at half the depth on SMT and Merkle stores of varying power-of-two
/// sizes.
fn get_node_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_node_simplesmt");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let smt_leaves = leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = SMT_MAX_DEPTH / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks getting a path of a leaf on the Merkle tree and Merkle store backends.
fn get_leaf_path_merkletree(c: &mut Criterion) {
let mut group = c.benchmark_group("get_leaf_path_merkletree");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
let store = MerkleStore::from(&mtree);
let depth = mtree.depth();
let root = mtree.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|index| black_box(mtree.get_path(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|index| black_box(store.get_path(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks getting a path of a leaf on the SMT and Merkle store backends.
fn get_leaf_path_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_leaf_path_simplesmt");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let smt_leaves = leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| {
black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
},
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(store.get_path(root, index)),
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks creation of the different storage backends
fn new(c: &mut Criterion) {
let mut group = c.benchmark_group("new");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
// MerkleTree constructor is optimized to work with vectors. Create a new copy of the data
// and pass it to the benchmark function
group.bench_function(BenchmarkId::new("MerkleTree::new", size), |b| {
b.iter_batched(
|| leaves.iter().map(|v| v.into()).collect::<Vec<Word>>(),
|l| black_box(MerkleTree::new(l)),
BatchSize::SmallInput,
)
});
// This could be done with `bench_with_input`, however to remove variables while comparing
// with MerkleTree it is using `iter_batched`
group.bench_function(BenchmarkId::new("MerkleStore::extend::MerkleTree", size), |b| {
b.iter_batched(
|| leaves.iter().map(|v| v.into()).collect::<Vec<Word>>(),
|l| {
let mtree = MerkleTree::new(l).unwrap();
black_box(MerkleStore::from(&mtree));
},
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("SimpleSmt::new", size), |b| {
b.iter_batched(
|| {
leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>()
},
|l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore::extend::SimpleSmt", size), |b| {
b.iter_batched(
|| {
leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>()
},
|l| {
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
black_box(MerkleStore::from(&smt));
},
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks updating a leaf on MerkleTree and MerkleStore backends.
fn update_leaf_merkletree(c: &mut Criterion) {
let mut group = c.benchmark_group("update_leaf_merkletree");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
let mut mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
let mut store = MerkleStore::from(&mtree);
let depth = mtree.depth();
let root = mtree.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
b.iter_batched(
|| (rand_value::<u64>() % size_u64, random_word()),
|(index, value)| black_box(mtree.update_leaf(index, value)),
BatchSize::SmallInput,
)
});
let mut store_root = root;
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| (random_index(size_u64, depth), random_word()),
|(index, value)| {
// The MerkleTree automatically updates its internal root, the Store maintains
// the old root and adds the new one. Here we update the root to have a fair
// comparison
store_root = store.set_node(root, index, value.into()).unwrap().root;
black_box(store_root)
},
BatchSize::SmallInput,
)
});
}
}
/// Benchmarks updating a leaf on SMT and MerkleStore backends.
fn update_leaf_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("update_leaf_simplesmt");
let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();
for size in BATCH_SIZES {
let leaves = &random_data[..size];
let smt_leaves = leaves
.iter()
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let mut store = MerkleStore::from(&smt);
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSMT", size), |b| {
b.iter_batched(
|| (rand_value::<u64>() % size_u64, random_word()),
|(index, value)| {
black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
},
BatchSize::SmallInput,
)
});
let mut store_root = root;
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
|(index, value)| {
// The MerkleTree automatically updates its internal root, the Store maintains
// the old root and adds the new one. Here we update the root to have a fair
// comparison
store_root = store.set_node(root, index, value.into()).unwrap().root;
black_box(store_root)
},
BatchSize::SmallInput,
)
});
}
}
criterion_group!(
store_group,
get_empty_leaf_simplesmt,
get_leaf_merkletree,
get_leaf_path_merkletree,
get_leaf_path_simplesmt,
get_leaf_simplesmt,
get_node_merkletree,
get_node_of_empty_simplesmt,
get_node_simplesmt,
new,
update_leaf_merkletree,
update_leaf_simplesmt,
);
criterion_main!(store_group);

View file

@ -1,19 +0,0 @@
fn main() {
#[cfg(target_feature = "sve")]
compile_arch_arm64_sve();
}
#[cfg(target_feature = "sve")]
fn compile_arch_arm64_sve() {
const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.c");
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.h");
println!("cargo:rerun-if-changed={RPO_SVE_PATH}/rpo_hash.h");
cc::Build::new()
.file(format!("{RPO_SVE_PATH}/library.c"))
.flag("-march=armv8-a+sve")
.flag("-O3")
.compile("rpo_sve");
}

View file

@ -1,5 +0,0 @@
[toolchain]
channel = "1.84"
components = ["rustfmt", "rust-src", "clippy"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"

View file

@ -1,23 +0,0 @@
edition = "2021"
array_width = 80
attr_fn_like_width = 80
chain_width = 80
comment_width = 100
condense_wildcard_suffixes = true
fn_call_width = 80
format_code_in_doc_comments = true
format_macro_matchers = true
group_imports = "StdExternalCrate"
hex_literal_case = "Lower"
imports_granularity = "Crate"
match_block_trailing_comma = true
newline_style = "Unix"
reorder_imports = true
reorder_modules = true
single_line_if_else_max_width = 60
single_line_let_else_max_width = 60
struct_lit_width = 40
struct_variant_width = 40
use_field_init_shorthand = true
use_try_shorthand = true
wrap_comments = true

View file

@ -1,21 +0,0 @@
#!/bin/bash
set -uo pipefail
CHANGELOG_FILE="${1:-CHANGELOG.md}"
if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then
# 'no changelog' set, so finish successfully
echo "\"no changelog\" label has been set"
exit 0
else
# a changelog check is required
# fail if the diff is empty
if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
>&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
can be overridden by using the \"no changelog\" label, which is used for changes
that are trivial / explicitly stated not to require a changelog entry."
exit 1
fi
echo "The \"CHANGELOG.md\" file has been updated."
fi

View file

@ -1,15 +0,0 @@
#!/bin/bash
# Get rust-toolchain.toml file channel
TOOLCHAIN_VERSION=$(grep 'channel' rust-toolchain.toml | sed -E 's/.*"(.*)".*/\1/')
# Get workspace Cargo.toml file rust-version
CARGO_VERSION=$(grep 'rust-version' Cargo.toml | sed -E 's/.*"(.*)".*/\1/')
# Check version match
if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
echo "Mismatch in Cargo.toml: Expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
exit 1
fi
echo "Rust versions match ✅"

1
series Submodule

@ -0,0 +1 @@
Subproject commit fa0943fc4864a76c98516177e9f7d781d35a57e6

View file

@ -1,3 +0,0 @@
//! Digital signature schemes supported by default in the Miden VM.
pub mod rpo_falcon512;

View file

@ -1,70 +0,0 @@
use alloc::vec::Vec;
use num::Zero;
use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};
// HASH-TO-POINT FUNCTIONS
// ================================================================================================
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using RPO256.
pub fn hash_to_point_rpo256(message: Word, nonce: &Nonce) -> Polynomial<FalconFelt> {
let mut state = [ZERO; Rpo256::STATE_WIDTH];
// absorb the nonce into the state
let nonce_elements = nonce.to_elements();
for (&n, s) in nonce_elements.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
*s = n;
}
Rpo256::apply_permutation(&mut state);
// absorb message into the state
for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
*s = m;
}
// squeeze the coefficients of the polynomial
let mut i = 0;
let mut res = [FalconFelt::zero(); N];
for _ in 0..64 {
Rpo256::apply_permutation(&mut state);
for a in &state[Rpo256::RATE_RANGE] {
res[i] = FalconFelt::new((a.as_int() % MODULUS as u64) as i16);
i += 1;
}
}
Polynomial::new(res.to_vec())
}
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using SHAKE256. This is the hash-to-point algorithm used in the reference implementation.
#[allow(dead_code)]
pub fn hash_to_point_shake256(message: &[u8], nonce: &Nonce) -> Polynomial<FalconFelt> {
use sha3::{
digest::{ExtendableOutput, Update, XofReader},
Shake256,
};
let mut data = vec![];
data.extend_from_slice(nonce.as_bytes());
data.extend_from_slice(message);
const K: u32 = (1u32 << 16) / MODULUS as u32;
let mut hasher = Shake256::default();
hasher.update(&data);
let mut reader = hasher.finalize_xof();
let mut coefficients: Vec<FalconFelt> = Vec::with_capacity(N);
while coefficients.len() != N {
let mut randomness = [0u8; 2];
reader.read(&mut randomness);
let t = ((randomness[0] as u32) << 8) | (randomness[1] as u32);
if t < K * MODULUS as u32 {
coefficients.push(FalconFelt::new((t % MODULUS as u32) as i16));
}
}
Polynomial { coefficients }
}

View file

@ -1,55 +0,0 @@
use super::{
math::{FalconFelt, Polynomial},
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Serializable, Signature,
Word,
};
mod public_key;
pub use public_key::{PubKeyPoly, PublicKey};
mod secret_key;
pub use secret_key::SecretKey;
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use winter_math::FieldElement;
use winter_utils::{Deserializable, Serializable};
use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};
#[test]
fn test_falcon_verification() {
let seed = [0_u8; 32];
let mut rng = ChaCha20Rng::from_seed(seed);
// generate random keys
let sk = SecretKey::with_rng(&mut rng);
let pk = sk.public_key();
// test secret key serialization/deserialization
let mut buffer = vec![];
sk.write_into(&mut buffer);
let sk_deserialized = SecretKey::read_from_bytes(&buffer).unwrap();
assert_eq!(sk.short_lattice_basis(), sk_deserialized.short_lattice_basis());
// sign a random message
let message: Word = [ONE; 4];
let signature = sk.sign_with_rng(message, &mut rng);
// make sure the signature verifies correctly
assert!(pk.verify(message, &signature));
// a signature should not verify against a wrong message
let message2: Word = [ONE.double(); 4];
assert!(!pk.verify(message2, &signature));
// a signature should not verify against a wrong public key
let sk2 = SecretKey::with_rng(&mut rng);
assert!(!sk2.public_key().verify(message, &signature))
}
}

View file

@ -1,139 +0,0 @@
use alloc::string::ToString;
use core::ops::Deref;
use num::Zero;
use super::{
super::{Rpo256, LOG_N, N, PK_LEN},
ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
Serializable, Signature, Word,
};
use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;
// PUBLIC KEY
// ================================================================================================
/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key. The hash is computed
/// using Rpo256.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);
impl PublicKey {
/// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
pub fn new(pub_key: Word) -> Self {
Self(pub_key)
}
/// Verifies the provided signature against provided message and this public key.
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
signature.verify(message, self.0)
}
}
impl From<PubKeyPoly> for PublicKey {
fn from(pk_poly: PubKeyPoly) -> Self {
let pk_felts: Polynomial<Felt> = pk_poly.0.into();
let pk_digest = Rpo256::hash_elements(&pk_felts.coefficients).into();
Self(pk_digest)
}
}
impl From<PublicKey> for Word {
fn from(key: PublicKey) -> Self {
key.0
}
}
// PUBLIC KEY POLYNOMIAL
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubKeyPoly(pub Polynomial<FalconFelt>);
impl Deref for PubKeyPoly {
type Target = Polynomial<FalconFelt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Polynomial<FalconFelt>> for PubKeyPoly {
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
Self(pk_poly)
}
}
impl Serializable for &PubKeyPoly {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let mut buf = [0_u8; PK_LEN];
buf[0] = LOG_N;
let mut acc = 0_u32;
let mut acc_len: u32 = 0;
let mut input_pos = 1;
for c in self.0.coefficients.iter() {
let c = c.value();
acc = (acc << FALCON_ENCODING_BITS) | c as u32;
acc_len += FALCON_ENCODING_BITS;
while acc_len >= 8 {
acc_len -= 8;
buf[input_pos] = (acc >> acc_len) as u8;
input_pos += 1;
}
}
if acc_len > 0 {
buf[input_pos] = (acc >> (8 - acc_len)) as u8;
}
target.write(buf);
}
}
impl Deserializable for PubKeyPoly {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let buf = source.read_array::<PK_LEN>()?;
if buf[0] != LOG_N {
return Err(DeserializationError::InvalidValue(format!(
"Failed to decode public key: expected the first byte to be {LOG_N} but was {}",
buf[0]
)));
}
let mut acc = 0_u32;
let mut acc_len = 0;
let mut output = [FalconFelt::zero(); N];
let mut output_idx = 0;
for &byte in buf.iter().skip(1) {
acc = (acc << 8) | (byte as u32);
acc_len += 8;
if acc_len >= FALCON_ENCODING_BITS {
acc_len -= FALCON_ENCODING_BITS;
let w = (acc >> acc_len) & 0x3fff;
let element = w.try_into().map_err(|err| {
DeserializationError::InvalidValue(format!(
"Failed to decode public key: {err}"
))
})?;
output[output_idx] = element;
output_idx += 1;
}
}
if (acc & ((1u32 << acc_len) - 1)) == 0 {
Ok(Polynomial::new(output.to_vec()).into())
} else {
Err(DeserializationError::InvalidValue(
"Failed to decode public key: input not fully consumed".to_string(),
))
}
}
}

View file

@ -1,401 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use num::Complex;
#[cfg(not(feature = "std"))]
use num::Float;
use num_complex::Complex64;
use rand::Rng;
use super::{
super::{
math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},
signature::SignaturePoly,
ByteReader, ByteWriter, Deserializable, DeserializationError, Nonce, Serializable,
ShortLatticeBasis, Signature, Word, MODULUS, N, SIGMA, SIG_L2_BOUND,
},
PubKeyPoly, PublicKey,
};
use crate::dsa::rpo_falcon512::{
hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
};
// CONSTANTS
// ================================================================================================
const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;
// SECRET KEY
// ================================================================================================
/// Represents the secret key for Falcon DSA.
///
/// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
/// polynomial is of degree at most N = 512 and computations with these polynomials is done modulo
/// the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice and has
/// the property of being short with respect to a certain norm and an upper bound appropriate for
/// a given security parameter. The public key on the other hand is another basis for the same
/// lattice and can be described by a single polynomial h with integer coefficients modulo ϕ.
/// The two keys are related by the following relation:
///
/// 1. h = g /f [mod ϕ][mod p]
/// 2. f.G - g.F = p [mod ϕ]
///
/// where p = 12289 is the Falcon prime. Equation 2 is called the NTRU equation.
/// The secret key is generated by first sampling a random pair (f, g) of polynomials using
/// an appropriate distribution that yields short but not too short polynomials with integer
/// coefficients modulo ϕ. The NTRU equation is then used to find a matching pair (F, G).
/// The public key is then derived from the secret key using equation 1.
///
/// To allow for fast signature generation, the secret key is pre-processed into a more suitable
/// form, called the LDL tree, and this allows for fast sampling of short vectors in the lattice
/// using Fast Fourier sampling during signature generation (ffSampling algorithm 11 in [1]).
///
/// [1]: https://falcon-sign.info/falcon.pdf
#[derive(Debug, Clone)]
pub struct SecretKey {
secret_key: ShortLatticeBasis,
tree: LdlTree,
}
#[allow(clippy::new_without_default)]
impl SecretKey {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Generates a secret key from OS-provided randomness.
#[cfg(feature = "std")]
pub fn new() -> Self {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
Self::with_rng(&mut rng)
}
/// Generates a secret_key using the provided random number generator `Rng`.
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
let basis = ntru_gen(N, rng);
Self::from_short_lattice_basis(basis)
}
/// Given a short basis [[g, -f], [G, -F]], computes the normalized LDL tree i.e., Falcon tree.
fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
// FFT each polynomial of the short basis.
let basis_fft = to_complex_fft(&basis);
// compute the Gram matrix.
let gram_fft = gram(basis_fft);
// construct the LDL tree of the Gram matrix.
let mut tree = ffldl(gram_fft);
// normalize the leaves of the LDL tree.
normalize_tree(&mut tree, SIGMA);
Self { secret_key: basis, tree }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the polynomials of the short lattice basis of this secret key.
pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
&self.secret_key
}
/// Returns the public key corresponding to this secret key.
pub fn public_key(&self) -> PublicKey {
self.compute_pub_key_poly().into()
}
/// Returns the LDL tree associated to this secret key.
pub fn tree(&self) -> &LdlTree {
&self.tree
}
// SIGNATURE GENERATION
// --------------------------------------------------------------------------------------------
/// Signs a message with this secret key.
#[cfg(feature = "std")]
pub fn sign(&self, message: Word) -> Signature {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
self.sign_with_rng(message, &mut rng)
}
/// Signs a message with the secret key relying on the provided randomness generator.
pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
let mut nonce_bytes = [0u8; SIG_NONCE_LEN];
rng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::new(nonce_bytes);
let h = self.compute_pub_key_poly();
let c = hash_to_point_rpo256(message, &nonce);
let s2 = self.sign_helper(c, rng);
Signature::new(nonce, h, s2)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Derives the public key corresponding to this secret key using h = g /f [mod ϕ][mod p].
pub fn compute_pub_key_poly(&self) -> PubKeyPoly {
let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
let g_fft = g.fft();
let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
let f = -minus_f;
let f_fft = f.fft();
let h_fft = g_fft.hadamard_div(&f_fft);
h_fft.ifft().into()
}
/// Signs a message polynomial with the secret key.
///
/// Takes a randomness generator implementing `Rng` and message polynomial representing `c`
/// the hash-to-point of the message to be signed. It outputs a signature polynomial `s2`.
fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
let one_over_q = 1.0 / (MODULUS as f64);
let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();
// B = [[FFT(g), -FFT(f)], [FFT(G), -FFT(F)]]
let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);
loop {
let bold_s = loop {
let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
let t0_min_z0 = t0.clone() - z.0;
let t1_min_z1 = t1.clone() - z.1;
// s = (t-z) * B
let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
let s1 =
t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);
// compute the norm of (s0||s1) and note that they are in FFT representation
let length_squared: f64 =
(s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
+ s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
/ (N as f64);
if length_squared > (SIG_L2_BOUND as f64) {
continue;
}
break [-s0, s1];
};
let s2 = bold_s[1].ifft();
let s2_coef: [i16; N] = s2
.coefficients
.iter()
.map(|a| a.re.round() as i16)
.collect::<Vec<i16>>()
.try_into()
.expect("The number of coefficients should be equal to N");
if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
return s2;
}
}
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for SecretKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let basis = &self.secret_key;
// header
let n = basis[0].coefficients.len();
let l = n.checked_ilog2().unwrap() as u8;
let header: u8 = (5 << 4) | l;
let neg_f = &basis[1];
let g = &basis[0];
let neg_big_f = &basis[3];
let mut buffer = Vec::with_capacity(1281);
buffer.push(header);
let f_i8: Vec<i8> = neg_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&f_i8_encoded);
let g_i8: Vec<i8> = g
.coefficients
.iter()
.map(|&a| FalconFelt::new(a).balanced_value() as i8)
.collect();
let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&g_i8_encoded);
let big_f_i8: Vec<i8> = neg_big_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&big_f_i8_encoded);
target.write_bytes(&buffer);
}
}
impl Deserializable for SecretKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let byte_vector: [u8; SK_LEN] = source.read_array()?;
// check length
if byte_vector.len() < 2 {
return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
}
// read fields
let header = byte_vector[0];
// check fixed bits in header
if (header >> 4) != 5 {
return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
}
// check log n
let logn = (header & 15) as usize;
let n = 1 << logn;
// match against const variant generic parameter
if n != N {
return Err(DeserializationError::InvalidValue(
"Unsupported Falcon DSA variant".to_string(),
));
}
if byte_vector.len() != SK_LEN {
return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
}
let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;
let f = decode_i8(&byte_vector[1..chunk_size_f + 1], WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
let g = decode_i8(
&byte_vector[chunk_size_f + 1..(chunk_size_f + chunk_size_g + 1)],
WIDTH_SMALL_POLY_COEFFICIENT,
)
.unwrap();
let big_f = decode_i8(
&byte_vector[(chunk_size_f + chunk_size_g + 1)
..(chunk_size_f + chunk_size_g + chunk_size_big_f + 1)],
WIDTH_BIG_POLY_COEFFICIENT,
)
.unwrap();
let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());
// big_g * f - g * big_f = p (mod X^n + 1)
let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();
let basis = [
g.map(|f| f.balanced_value()),
-f.map(|f| f.balanced_value()),
big_g.map(|f| f.balanced_value()),
-big_f.map(|f| f.balanced_value()),
];
Ok(Self::from_short_lattice_basis(basis))
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Computes the complex FFT of the secret key polynomials.
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
let [g, f, big_g, big_f] = basis.clone();
let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
[g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}
/// Encodes a sequence of signed integers such that each integer x satisfies |x| < 2^(bits-1)
/// for a given parameter bits. bits can take either the value 6 or 8.
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
let maxv = (1 << (bits - 1)) - 1_usize;
let maxv = maxv as i8;
let minv = -maxv;
for &c in x {
if c > maxv || c < minv {
return None;
}
}
let out_len = ((N * bits) + 7) >> 3;
let mut buf = vec![0_u8; out_len];
let mut acc = 0_u32;
let mut acc_len = 0;
let mask = ((1_u16 << bits) - 1) as u8;
let mut input_pos = 0;
for &c in x {
acc = (acc << bits) | (c as u8 & mask) as u32;
acc_len += bits;
while acc_len >= 8 {
acc_len -= 8;
buf[input_pos] = (acc >> acc_len) as u8;
input_pos += 1;
}
}
if acc_len > 0 {
buf[input_pos] = (acc >> (8 - acc_len)) as u8;
}
Some(buf)
}
/// Decodes a sequence of bytes into a sequence of signed integers such that each integer x
/// satisfies |x| < 2^(bits-1) for a given parameter bits. bits can take either the value 6 or 8.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
let mut x = [0_i8; N];
let mut i = 0;
let mut j = 0;
let mut acc = 0_u32;
let mut acc_len = 0;
let mask = (1_u32 << bits) - 1;
let a = (1 << bits) as u8;
let b = ((1 << (bits - 1)) - 1) as u8;
while i < N {
acc = (acc << 8) | (buf[j] as u32);
j += 1;
acc_len += 8;
while acc_len >= bits && i < N {
acc_len -= bits;
let w = (acc >> acc_len) & mask;
let w = w as u8;
let z = if w > b { w as i8 - a as i8 } else { w as i8 };
x[i] = z;
i += 1;
}
}
if (acc & ((1u32 << acc_len) - 1)) == 0 {
Some(x.to_vec())
} else {
None
}
}

View file

@ -1,124 +0,0 @@
use alloc::boxed::Box;
#[cfg(not(feature = "std"))]
use num::Float;
use num::{One, Zero};
use num_complex::{Complex, Complex64};
use rand::Rng;
use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};
const SIGMIN: f64 = 1.2778336969128337;
/// Computes the Gram matrix. The argument must be a 2x2 matrix
/// whose elements are equal-length vectors of complex numbers,
/// representing polynomials in FFT domain.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
const N: usize = 2;
let mut g: [Polynomial<Complex<f64>>; 4] =
[Polynomial::zero(), Polynomial::zero(), Polynomial::zero(), Polynomial::zero()];
for i in 0..N {
for j in 0..N {
for k in 0..N {
g[N * i + j] = g[N * i + j].clone()
+ b[N * i + k].hadamard_mul(&b[N * j + k].map(|c| c.conj()));
}
}
}
g
}
/// Computes the LDL decomposition of a 2x2 matrix G such that
/// L D L* = G
/// where D is diagonal, and L is lower-triangular. The elements of the matrices are in FFT domain.
pub fn ldl(
g: [Polynomial<Complex64>; 4],
) -> ([Polynomial<Complex64>; 4], [Polynomial<Complex64>; 4]) {
let zero = Polynomial::<Complex64>::one();
let one = Polynomial::<Complex64>::zero();
let l10 = g[2].hadamard_div(&g[0]);
let bc = l10.map(|c| c * c.conj());
let abc = g[0].hadamard_mul(&bc);
let d11 = g[3].clone() - abc;
let l = [one.clone(), zero.clone(), l10.clone(), one];
let d = [g[0].clone(), zero.clone(), zero, d11];
(l, d)
}
#[derive(Debug, Clone)]
pub enum LdlTree {
Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
Leaf([Complex64; 2]),
}
/// Computes the LDL Tree of G. Corresponds to Algorithm 9 of the specification [1, p.37].
/// The argument is a 2x2 matrix of polynomials, given in FFT form.
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
let n = gram_matrix[0].coefficients.len();
let (l, d) = ldl(gram_matrix);
if n > 2 {
let (d00, d01) = d[0].split_fft();
let (d10, d11) = d[3].split_fft();
let g0 = [d00.clone(), d01.clone(), d01.map(|c| c.conj()), d00];
let g1 = [d10.clone(), d11.clone(), d11.map(|c| c.conj()), d10];
LdlTree::Branch(l[2].clone(), Box::new(ffldl(g0)), Box::new(ffldl(g1)))
} else {
LdlTree::Branch(
l[2].clone(),
Box::new(LdlTree::Leaf(d[0].clone().coefficients.try_into().unwrap())),
Box::new(LdlTree::Leaf(d[3].clone().coefficients.try_into().unwrap())),
)
}
}
/// Normalizes the leaves of an LDL tree using a given normalization value `sigma`.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
match tree {
LdlTree::Branch(_ell, left, right) => {
normalize_tree(left, sigma);
normalize_tree(right, sigma);
},
LdlTree::Leaf(vector) => {
vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
vector[1] = Complex64::zero();
},
}
}
/// Samples short polynomials using a Falcon tree. Algorithm 11 from the spec [1, p.40].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffsampling<R: Rng>(
t: &(Polynomial<Complex64>, Polynomial<Complex64>),
tree: &LdlTree,
mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
match tree {
LdlTree::Branch(ell, left, right) => {
let bold_t1 = t.1.split_fft();
let bold_z1 = ffsampling(&bold_t1, right, rng);
let z1 = Polynomial::<Complex64>::merge_fft(&bold_z1.0, &bold_z1.1);
// t0' = t0 + (t1 - z1) * l
let t0_prime = t.0.clone() + (t.1.clone() - z1.clone()).hadamard_mul(ell);
let bold_t0 = t0_prime.split_fft();
let bold_z0 = ffsampling(&bold_t0, left, rng);
let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);
(z0, z1)
},
LdlTree::Leaf(value) => {
let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
(
Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
)
},
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,174 +0,0 @@
use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::{One, Zero};
use super::{fft::CyclotomicFourier, Inverse, MODULUS};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32);
impl FalconFelt {
pub const fn new(value: i16) -> Self {
let gtz_bool = value >= 0;
let gtz_int = gtz_bool as i16;
let gtz_sign = gtz_int - ((!gtz_bool) as i16);
let reduced = gtz_sign * (gtz_sign * value) % MODULUS;
let canonical_representative = (reduced + MODULUS * (1 - gtz_int)) as u32;
FalconFelt(canonical_representative)
}
pub const fn value(&self) -> i16 {
self.0 as i16
}
pub fn balanced_value(&self) -> i16 {
let value = self.value();
let g = (value > ((MODULUS) / 2)) as i16;
value - (MODULUS) * g
}
pub const fn multiply(&self, other: Self) -> Self {
FalconFelt((self.0 * other.0) % MODULUS as u32)
}
}
impl Add for FalconFelt {
type Output = Self;
#[allow(clippy::suspicious_arithmetic_impl)]
fn add(self, rhs: Self) -> Self::Output {
let (s, _) = self.0.overflowing_add(rhs.0);
let (d, n) = s.overflowing_sub(MODULUS as u32);
let (r, _) = d.overflowing_add(MODULUS as u32 * (n as u32));
FalconFelt(r)
}
}
impl AddAssign for FalconFelt {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl Sub for FalconFelt {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
self + -rhs
}
}
impl SubAssign for FalconFelt {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Neg for FalconFelt {
type Output = FalconFelt;
fn neg(self) -> Self::Output {
let is_nonzero = self.0 != 0;
let r = MODULUS as u32 - self.0;
FalconFelt(r * (is_nonzero as u32))
}
}
impl Mul for FalconFelt {
fn mul(self, rhs: Self) -> Self::Output {
FalconFelt((self.0 * rhs.0) % MODULUS as u32)
}
type Output = Self;
}
impl MulAssign for FalconFelt {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Div for FalconFelt {
type Output = FalconFelt;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output {
self * rhs.inverse_or_zero()
}
}
impl DivAssign for FalconFelt {
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs
}
}
impl Zero for FalconFelt {
fn zero() -> Self {
FalconFelt::new(0)
}
fn is_zero(&self) -> bool {
self.0 == 0
}
}
impl One for FalconFelt {
fn one() -> Self {
FalconFelt::new(1)
}
}
impl Inverse for FalconFelt {
fn inverse_or_zero(self) -> Self {
// q-2 = 0b10 11 11 11 11 11 11
let two = self.multiply(self);
let three = two.multiply(self);
let six = three.multiply(three);
let twelve = six.multiply(six);
let fifteen = twelve.multiply(three);
let thirty = fifteen.multiply(fifteen);
let sixty = thirty.multiply(thirty);
let sixty_three = sixty.multiply(three);
let sixty_three_sq = sixty_three.multiply(sixty_three);
let sixty_three_qu = sixty_three_sq.multiply(sixty_three_sq);
let sixty_three_oc = sixty_three_qu.multiply(sixty_three_qu);
let sixty_three_hx = sixty_three_oc.multiply(sixty_three_oc);
let sixty_three_tt = sixty_three_hx.multiply(sixty_three_hx);
let sixty_three_sf = sixty_three_tt.multiply(sixty_three_tt);
let all_ones = sixty_three_sf.multiply(sixty_three);
let two_e_twelve = all_ones.multiply(self);
let two_e_thirteen = two_e_twelve.multiply(two_e_twelve);
two_e_thirteen.multiply(all_ones)
}
}
impl CyclotomicFourier for FalconFelt {
fn primitive_root_of_unity(n: usize) -> Self {
let log2n = n.ilog2();
assert!(log2n <= 12);
// and 1331 is a twelfth root of unity
let mut a = FalconFelt::new(1331);
let num_squarings = 12 - n.ilog2();
for _ in 0..num_squarings {
a *= a;
}
a
}
}
impl TryFrom<u32> for FalconFelt {
type Error = String;
fn try_from(value: u32) -> Result<Self, Self::Error> {
if value >= MODULUS as u32 {
Err(format!("value {value} is greater than or equal to the field modulus {MODULUS}"))
} else {
Ok(FalconFelt::new(value as i16))
}
}
}

View file

@ -1,322 +0,0 @@
//! Contains different structs and methods related to the Falcon DSA.
//!
//! It uses and acknowledges the work in:
//!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas
//! Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use alloc::{string::String, vec::Vec};
use core::ops::MulAssign;
#[cfg(not(feature = "std"))]
use num::Float;
use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64;
use rand::Rng;
use super::MODULUS;
mod fft;
pub use fft::{CyclotomicFourier, FastFft};
mod field;
pub use field::FalconFelt;
mod ffsampling;
pub use ffsampling::{ffldl, ffsampling, gram, normalize_tree, LdlTree};
mod samplerz;
use self::samplerz::sampler_z;
mod polynomial;
pub use polynomial::Polynomial;
pub trait Inverse: Copy + Zero + MulAssign + One {
/// Gets the inverse of a, or zero if it is zero.
fn inverse_or_zero(self) -> Self;
/// Gets the inverses of a batch of elements, and skip over any that are zero.
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
let mut acc = Self::one();
let mut rp: Vec<Self> = Vec::with_capacity(batch.len());
for batch_item in batch {
if !batch_item.is_zero() {
rp.push(acc);
acc = *batch_item * acc;
} else {
rp.push(Self::zero());
}
}
let mut inv = Self::inverse_or_zero(acc);
for i in (0..batch.len()).rev() {
if !batch[i].is_zero() {
rp[i] *= inv;
inv *= batch[i];
}
}
rp
}
}
impl Inverse for Complex64 {
fn inverse_or_zero(self) -> Self {
let modulus = self.re * self.re + self.im * self.im;
Complex64::new(self.re / modulus, -self.im / modulus)
}
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
batch.iter().map(|&c| Complex64::new(1.0, 0.0) / c).collect()
}
}
impl Inverse for f64 {
fn inverse_or_zero(self) -> Self {
1.0 / self
}
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
batch.iter().map(|&c| 1.0 / c).collect()
}
}
/// Samples 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub(crate) fn ntru_gen<R: Rng>(n: usize, rng: &mut R) -> [Polynomial<i16>; 4] {
loop {
let f = gen_poly(n, rng);
let g = gen_poly(n, rng);
let f_ntt = f.map(|&i| FalconFelt::new(i)).fft();
if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
continue;
}
let gamma = gram_schmidt_norm_squared(&f, &g);
if gamma > 1.3689f64 * (MODULUS as f64) {
continue;
}
if let Some((capital_f, capital_g)) =
ntru_solve(&f.map(|&i| i.into()), &g.map(|&i| i.into()))
{
return [
g,
-f,
capital_g.map(|i| i.try_into().unwrap()),
-capital_f.map(|i| i.try_into().unwrap()),
];
}
}
}
/// Solves the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X] such that:
///
/// f G - g F = q mod (X^n + 1)
///
/// Algorithm 6 of the specification [1, p.35].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn ntru_solve(
f: &Polynomial<BigInt>,
g: &Polynomial<BigInt>,
) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
let n = f.coefficients.len();
if n == 1 {
let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
if gcd != BigInt::one() {
return None;
}
return Some((
(Polynomial::new(vec![-v * BigInt::from_u32(MODULUS as u32).unwrap()])),
Polynomial::new(vec![u * BigInt::from_u32(MODULUS as u32).unwrap()]),
));
}
let f_prime = f.field_norm();
let g_prime = g.field_norm();
let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();
let f_minx = f.galois_adjoint();
let g_minx = g.galois_adjoint();
let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);
match babai_reduce(f, g, &mut capital_f, &mut capital_g) {
Ok(_) => Some((capital_f, capital_g)),
Err(_e) => {
#[cfg(test)]
{
panic!("{}", _e);
}
#[cfg(not(test))]
{
None
}
},
}
}
/// Generates a polynomial of degree at most n-1 whose coefficients are distributed according
/// to a discrete Gaussian with mu = 0 and sigma = 1.17 * sqrt(Q / (2n)).
fn gen_poly<R: Rng>(n: usize, rng: &mut R) -> Polynomial<i16> {
let mu = 0.0;
let sigma_star = 1.43300980528773;
Polynomial {
coefficients: (0..4096)
.map(|_| sampler_z(mu, sigma_star, sigma_star - 0.001, rng))
.collect::<Vec<i16>>()
.chunks(4096 / n)
.map(|ch| ch.iter().sum())
.collect(),
}
}
/// Computes the Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34]
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
let n = f.coefficients.len();
let norm_f_squared = f.l2_norm_squared();
let norm_g_squared = g.l2_norm_squared();
let gamma1 = norm_f_squared + norm_g_squared;
let f_fft = f.map(|i| Complex64::new(*i as f64, 0.0)).fft();
let g_fft = g.map(|i| Complex64::new(*i as f64, 0.0)).fft();
let f_adj_fft = f_fft.map(|c| c.conj());
let g_adj_fft = g_fft.map(|c| c.conj());
let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
let qf_over_ffgg_fft = f_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
let qg_over_ffgg_fft = g_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
let norm_f_over_ffgg_squared =
qf_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
let norm_g_over_ffgg_squared =
qg_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;
f64::max(gamma1, gamma2)
}
/// Reduces the vector (F,G) relative to (f,g). This method follows the python implementation [1].
/// Note that this algorithm can end up in an infinite loop. (It's one of the things the author
/// would like to fix.) When this happens, control returns an error (hence the return type) and
/// generates another keypair with fresh randomness.
///
/// Algorithm 7 in the spec [2, p.35]
///
/// [1]: https://github.com/tprest/falcon.py
///
/// [2]: https://falcon-sign.info/falcon.pdf
fn babai_reduce(
f: &Polynomial<BigInt>,
g: &Polynomial<BigInt>,
capital_f: &mut Polynomial<BigInt>,
capital_g: &mut Polynomial<BigInt>,
) -> Result<(), String> {
let bitsize = |bi: &BigInt| (bi.bits() + 7) & (u64::MAX ^ 7);
let n = f.coefficients.len();
let size = [
f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
53,
]
.into_iter()
.max()
.unwrap();
let shift = (size as i64) - 53;
let f_adjusted = f
.map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
.fft();
let g_adjusted = g
.map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
.fft();
let f_star_adjusted = f_adjusted.map(|c| c.conj());
let g_star_adjusted = g_adjusted.map(|c| c.conj());
let denominator_fft =
f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);
let mut counter = 0;
loop {
let capital_size = [
capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
53,
]
.into_iter()
.max()
.unwrap();
if capital_size < size {
break;
}
let capital_shift = (capital_size as i64) - 53;
let capital_f_adjusted = capital_f
.map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
.fft();
let capital_g_adjusted = capital_g
.map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
.fft();
let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
+ capital_g_adjusted.hadamard_mul(&g_star_adjusted);
let quotient = numerator.hadamard_div(&denominator_fft).ifft();
let k = quotient.map(|f| Into::<BigInt>::into(f.re.round() as i64));
if k.is_zero() {
break;
}
let kf = (k.clone().karatsuba(f))
.reduce_by_cyclotomic(n)
.map(|bi| bi << (capital_size - size));
let kg = (k.clone().karatsuba(g))
.reduce_by_cyclotomic(n)
.map(|bi| bi << (capital_size - size));
*capital_f -= kf;
*capital_g -= kg;
counter += 1;
if counter > 1000 {
// If we get here, that means that (with high likelihood) we are in an
// infinite loop. We know it happens from time to time -- seldomly, but it
// does. It would be nice to fix that! But in order to fix it we need to be
// able to reproduce it, and for that we need test vectors. So print them
// and hope that one day they circle back to the implementor.
return Err(format!("Encountered infinite loop in babai_reduce of falcon-rust.\n\\
Please help the developer(s) fix it! You can do this by sending them the inputs to the function that caused the behavior:\n\\
f: {:?}\n\\
g: {:?}\n\\
capital_f: {:?}\n\\
capital_g: {:?}\n", f.coefficients, g.coefficients, capital_f.coefficients, capital_g.coefficients));
}
}
Ok(())
}
/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
/// Bézout coefficients (u, v) for the relation
///
/// $$ u a + v b = g . $$
///
/// Implementation adapted from Wikipedia [1].
///
/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
let (mut old_r, mut r) = (a.clone(), b.clone());
let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());
while r != BigInt::zero() {
let quotient = old_r.clone() / r.clone();
(old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
(old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
(old_t, t) = (t.clone(), old_t.clone() - quotient * t);
}
(old_r, old_s, old_t)
}

View file

@ -1,622 +0,0 @@
use alloc::vec::Vec;
use core::{
default::Default,
fmt::Debug,
ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};
use num::{One, Zero};
use super::{field::FalconFelt, Inverse};
use crate::{
dsa::rpo_falcon512::{MODULUS, N},
Felt,
};
#[derive(Debug, Clone, Default)]
pub struct Polynomial<F> {
pub coefficients: Vec<F>,
}
impl<F> Polynomial<F>
where
F: Clone,
{
pub fn new(coefficients: Vec<F>) -> Self {
Self { coefficients }
}
}
impl<
F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone + Inverse,
> Polynomial<F>
{
pub fn hadamard_mul(&self, other: &Self) -> Self {
Polynomial::new(
self.coefficients
.iter()
.zip(other.coefficients.iter())
.map(|(a, b)| *a * *b)
.collect(),
)
}
pub fn hadamard_div(&self, other: &Self) -> Self {
let other_coefficients_inverse = F::batch_inverse_or_zero(&other.coefficients);
Polynomial::new(
self.coefficients
.iter()
.zip(other_coefficients_inverse.iter())
.map(|(a, b)| *a * *b)
.collect(),
)
}
pub fn hadamard_inv(&self) -> Self {
let coefficients_inverse = F::batch_inverse_or_zero(&self.coefficients);
Polynomial::new(coefficients_inverse)
}
}
impl<F: Zero + PartialEq + Clone> Polynomial<F> {
pub fn degree(&self) -> Option<usize> {
if self.coefficients.is_empty() {
return None;
}
let mut max_index = self.coefficients.len() - 1;
while self.coefficients[max_index] == F::zero() {
if let Some(new_index) = max_index.checked_sub(1) {
max_index = new_index;
} else {
return None;
}
}
Some(max_index)
}
pub fn lc(&self) -> F {
match self.degree() {
Some(non_negative_degree) => self.coefficients[non_negative_degree].clone(),
None => F::zero(),
}
}
}
/// The following implementations are specific to cyclotomic polynomial rings,
/// i.e., F\[ X \] / <X^n + 1>, and are used extensively in Falcon.
impl<
F: One
+ Zero
+ Clone
+ Neg<Output = F>
+ MulAssign
+ AddAssign
+ Div<Output = F>
+ Sub<Output = F>
+ PartialEq,
> Polynomial<F>
{
/// Reduce the polynomial by X^n + 1.
pub fn reduce_by_cyclotomic(&self, n: usize) -> Self {
let mut coefficients = vec![F::zero(); n];
let mut sign = -F::one();
for (i, c) in self.coefficients.iter().cloned().enumerate() {
if i % n == 0 {
sign *= -F::one();
}
coefficients[i % n] += sign.clone() * c;
}
Polynomial::new(coefficients)
}
/// Computes the field norm of the polynomial as an element of the cyclotomic ring
/// F\[ X \] / <X^n + 1 > relative to one of half the size, i.e., F\[ X \] / <X^(n/2) + 1> .
///
/// Corresponds to formula 3.25 in the spec [1, p.30].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn field_norm(&self) -> Self {
let n = self.coefficients.len();
let mut f0_coefficients = vec![F::zero(); n / 2];
let mut f1_coefficients = vec![F::zero(); n / 2];
for i in 0..n / 2 {
f0_coefficients[i] = self.coefficients[2 * i].clone();
f1_coefficients[i] = self.coefficients[2 * i + 1].clone();
}
let f0 = Polynomial::new(f0_coefficients);
let f1 = Polynomial::new(f1_coefficients);
let f0_squared = (f0.clone() * f0).reduce_by_cyclotomic(n / 2);
let f1_squared = (f1.clone() * f1).reduce_by_cyclotomic(n / 2);
let x = Polynomial::new(vec![F::zero(), F::one()]);
f0_squared - (x * f1_squared).reduce_by_cyclotomic(n / 2)
}
/// Lifts an element from a cyclotomic polynomial ring to one of double the size.
pub fn lift_next_cyclotomic(&self) -> Self {
let n = self.coefficients.len();
let mut coefficients = vec![F::zero(); n * 2];
for i in 0..n {
coefficients[2 * i] = self.coefficients[i].clone();
}
Self::new(coefficients)
}
/// Computes the galois adjoint of the polynomial in the cyclotomic ring F\[ X \] / < X^n + 1 >
/// , which corresponds to f(x^2).
pub fn galois_adjoint(&self) -> Self {
Self::new(
self.coefficients
.iter()
.enumerate()
.map(|(i, c)| if i % 2 == 0 { c.clone() } else { c.clone().neg() })
.collect(),
)
}
}
impl<F: Clone + Into<f64>> Polynomial<F> {
pub(crate) fn l2_norm_squared(&self) -> f64 {
self.coefficients
.iter()
.map(|i| Into::<f64>::into(i.clone()))
.map(|i| i * i)
.sum::<f64>()
}
}
impl<F> PartialEq for Polynomial<F>
where
F: Zero + PartialEq + Clone + AddAssign,
{
fn eq(&self, other: &Self) -> bool {
if self.is_zero() && other.is_zero() {
true
} else if self.is_zero() || other.is_zero() {
false
} else {
let self_degree = self.degree().unwrap();
let other_degree = other.degree().unwrap();
self.coefficients[0..=self_degree] == other.coefficients[0..=other_degree]
}
}
}
impl<F> Eq for Polynomial<F> where F: Zero + PartialEq + Clone + AddAssign {}
impl<F> Add for &Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
type Output = Polynomial<F>;
fn add(self, rhs: Self) -> Self::Output {
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
let mut coefficients = self.coefficients.clone();
for (i, c) in rhs.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
coefficients
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
coefficients
};
Self::Output { coefficients }
}
}
impl<F> Add for Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
type Output = Polynomial<F>;
fn add(self, rhs: Self) -> Self::Output {
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
let mut coefficients = self.coefficients.clone();
for (i, c) in rhs.coefficients.into_iter().enumerate() {
coefficients[i] += c;
}
coefficients
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.into_iter().enumerate() {
coefficients[i] += c;
}
coefficients
};
Self::Output { coefficients }
}
}
impl<F> AddAssign for Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
fn add_assign(&mut self, rhs: Self) {
if self.coefficients.len() >= rhs.coefficients.len() {
for (i, c) in rhs.coefficients.into_iter().enumerate() {
self.coefficients[i] += c;
}
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
self.coefficients = coefficients;
}
}
}
impl<F> Sub for &Polynomial<F>
where
F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
type Output = Polynomial<F>;
fn sub(self, rhs: Self) -> Self::Output {
self + &(-rhs)
}
}
impl<F> Sub for Polynomial<F>
where
F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
type Output = Polynomial<F>;
fn sub(self, rhs: Self) -> Self::Output {
self + (-rhs)
}
}
impl<F> SubAssign for Polynomial<F>
where
F: Add<Output = F> + Neg<Output = F> + AddAssign + Clone + Sub<Output = F>,
{
fn sub_assign(&mut self, rhs: Self) {
self.coefficients = self.clone().sub(rhs).coefficients;
}
}
impl<F: Neg<Output = F> + Clone> Neg for &Polynomial<F> {
type Output = Polynomial<F>;
fn neg(self) -> Self::Output {
Self::Output {
coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
}
}
}
impl<F: Neg<Output = F> + Clone> Neg for Polynomial<F> {
type Output = Self;
fn neg(self) -> Self::Output {
Self::Output {
coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
}
}
}
impl<F> Mul for &Polynomial<F>
where
F: Add + AddAssign + Mul<Output = F> + Sub<Output = F> + Zero + PartialEq + Clone,
{
type Output = Polynomial<F>;
fn mul(self, other: Self) -> Self::Output {
if self.is_zero() || other.is_zero() {
return Polynomial::<F>::zero();
}
let mut coefficients =
vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
for i in 0..self.coefficients.len() {
for j in 0..other.coefficients.len() {
coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
}
}
Polynomial { coefficients }
}
}
impl<F> Mul for Polynomial<F>
where
F: Add + AddAssign + Mul<Output = F> + Zero + PartialEq + Clone,
{
type Output = Self;
fn mul(self, other: Self) -> Self::Output {
if self.is_zero() || other.is_zero() {
return Self::zero();
}
let mut coefficients =
vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
for i in 0..self.coefficients.len() {
for j in 0..other.coefficients.len() {
coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
}
}
Self { coefficients }
}
}
impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for &Polynomial<F> {
type Output = Polynomial<F>;
fn mul(self, other: F) -> Self::Output {
Polynomial {
coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
}
}
}
impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for Polynomial<F> {
type Output = Polynomial<F>;
fn mul(self, other: F) -> Self::Output {
Polynomial {
coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
}
}
}
impl<F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone>
Polynomial<F>
{
/// Multiply two polynomials using Karatsuba's divide-and-conquer algorithm.
pub fn karatsuba(&self, other: &Self) -> Self {
Polynomial::new(vector_karatsuba(&self.coefficients, &other.coefficients))
}
}
impl<F> One for Polynomial<F>
where
F: Clone + One + PartialEq + Zero + AddAssign,
{
fn one() -> Self {
Self { coefficients: vec![F::one()] }
}
}
impl<F> Zero for Polynomial<F>
where
F: Zero + PartialEq + Clone + AddAssign,
{
fn zero() -> Self {
Self { coefficients: vec![] }
}
fn is_zero(&self) -> bool {
self.degree().is_none()
}
}
impl<F: Zero + Clone> Polynomial<F> {
pub fn shift(&self, shamt: usize) -> Self {
Self {
coefficients: [vec![F::zero(); shamt], self.coefficients.clone()].concat(),
}
}
pub fn constant(f: F) -> Self {
Self { coefficients: vec![f] }
}
pub fn map<G: Clone, C: FnMut(&F) -> G>(&self, closure: C) -> Polynomial<G> {
Polynomial::<G>::new(self.coefficients.iter().map(closure).collect())
}
pub fn fold<G, C: FnMut(G, &F) -> G + Clone>(&self, mut initial_value: G, closure: C) -> G {
for c in self.coefficients.iter() {
initial_value = (closure.clone())(initial_value, c);
}
initial_value
}
}
impl<F> Div<Polynomial<F>> for Polynomial<F>
where
F: Zero
+ One
+ PartialEq
+ AddAssign
+ Clone
+ Mul<Output = F>
+ MulAssign
+ Div<Output = F>
+ Neg<Output = F>
+ Sub<Output = F>,
{
type Output = Polynomial<F>;
fn div(self, denominator: Self) -> Self::Output {
if denominator.is_zero() {
panic!();
}
if self.is_zero() {
Self::zero();
}
let mut remainder = self.clone();
let mut quotient = Polynomial::<F>::zero();
while remainder.degree().unwrap() >= denominator.degree().unwrap() {
let shift = remainder.degree().unwrap() - denominator.degree().unwrap();
let quotient_coefficient = remainder.lc() / denominator.lc();
let monomial = Self::constant(quotient_coefficient).shift(shift);
quotient += monomial.clone();
remainder -= monomial * denominator.clone();
if remainder.is_zero() {
break;
}
}
quotient
}
}
fn vector_karatsuba<
F: Zero + AddAssign + Mul<Output = F> + Sub<Output = F> + Div<Output = F> + Clone,
>(
left: &[F],
right: &[F],
) -> Vec<F> {
let n = left.len();
if n <= 8 {
let mut product = vec![F::zero(); left.len() + right.len() - 1];
for (i, l) in left.iter().enumerate() {
for (j, r) in right.iter().enumerate() {
product[i + j] += l.clone() * r.clone();
}
}
return product;
}
let n_over_2 = n / 2;
let mut product = vec![F::zero(); 2 * n - 1];
let left_lo = &left[0..n_over_2];
let right_lo = &right[0..n_over_2];
let left_hi = &left[n_over_2..];
let right_hi = &right[n_over_2..];
let left_sum: Vec<F> =
left_lo.iter().zip(left_hi).map(|(a, b)| a.clone() + b.clone()).collect();
let right_sum: Vec<F> =
right_lo.iter().zip(right_hi).map(|(a, b)| a.clone() + b.clone()).collect();
let prod_lo = vector_karatsuba(left_lo, right_lo);
let prod_hi = vector_karatsuba(left_hi, right_hi);
let prod_mid: Vec<F> = vector_karatsuba(&left_sum, &right_sum)
.iter()
.zip(prod_lo.iter().zip(prod_hi.iter()))
.map(|(s, (l, h))| s.clone() - (l.clone() + h.clone()))
.collect();
for (i, l) in prod_lo.into_iter().enumerate() {
product[i] = l;
}
for (i, m) in prod_mid.into_iter().enumerate() {
product[i + n_over_2] += m;
}
for (i, h) in prod_hi.into_iter().enumerate() {
product[i + n] += h
}
product
}
impl From<Polynomial<FalconFelt>> for Polynomial<Felt> {
fn from(item: Polynomial<FalconFelt>) -> Self {
let res: Vec<Felt> =
item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
Polynomial::new(res)
}
}
impl From<&Polynomial<FalconFelt>> for Polynomial<Felt> {
fn from(item: &Polynomial<FalconFelt>) -> Self {
let res: Vec<Felt> =
item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
Polynomial::new(res)
}
}
impl From<Polynomial<i16>> for Polynomial<FalconFelt> {
fn from(item: Polynomial<i16>) -> Self {
let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<&Polynomial<i16>> for Polynomial<FalconFelt> {
fn from(item: &Polynomial<i16>) -> Self {
let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<Vec<i16>> for Polynomial<FalconFelt> {
fn from(item: Vec<i16>) -> Self {
let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<&Vec<i16>> for Polynomial<FalconFelt> {
fn from(item: &Vec<i16>) -> Self {
let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl Polynomial<FalconFelt> {
pub fn norm_squared(&self) -> u64 {
self.coefficients
.iter()
.map(|&i| i.balanced_value() as i64)
.map(|i| (i * i) as u64)
.sum::<u64>()
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the coefficients of this polynomial as field elements.
pub fn to_elements(&self) -> Vec<Felt> {
self.coefficients.iter().map(|&a| Felt::from(a.value() as u16)).collect()
}
// POLYNOMIAL OPERATIONS
// --------------------------------------------------------------------------------------------
/// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
/// than the Miden prime.
///
/// Note that this multiplication is not over Z_p\[x\]/(phi).
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
let mut c = [0; 2 * N];
for i in 0..N {
for j in 0..N {
c[i + j] += a.coefficients[i].value() as u64 * b.coefficients[j].value() as u64;
}
}
c
}
/// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
/// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
let mut c = [FalconFelt::zero(); N];
let modulus = MODULUS as u16;
for i in 0..N {
let ai = a[N + i] % modulus as u64;
let neg_ai = (modulus - ai as u16) % modulus;
let bi = (a[i] % modulus as u64) as u16;
c[i] = FalconFelt::new(((neg_ai + bi) % modulus) as i16);
}
Self::new(c.to_vec())
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{FalconFelt, Polynomial, N};
#[test]
fn test_negacyclic_reduction() {
let coef1: [u8; N] = rand_utils::rand_array();
let coef2: [u8; N] = rand_utils::rand_array();
let poly1 = Polynomial::new(coef1.iter().map(|&a| FalconFelt::new(a as i16)).collect());
let poly2 = Polynomial::new(coef2.iter().map(|&a| FalconFelt::new(a as i16)).collect());
let prod = poly1.clone() * poly2.clone();
assert_eq!(
prod.reduce_by_cyclotomic(N),
Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
);
}
}

View file

@ -1,299 +0,0 @@
use core::f64::consts::LN_2;
#[cfg(not(feature = "std"))]
use num::Float;
use rand::Rng;
/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
/// equal to sigma_max.
fn base_sampler(bytes: [u8; 9]) -> i16 {
const RCDT: [u128; 18] = [
3024686241123004913666,
1564742784480091954050,
636254429462080897535,
199560484645026482916,
47667343854657281903,
8595902006365044063,
1163297957344668388,
117656387352093658,
8867391802663976,
496969357462633,
20680885154299,
638331848991,
14602316184,
247426747,
3104126,
28824,
198,
1,
];
let u = u128::from_be_bytes([vec![0u8; 7], bytes.to_vec()].concat().try_into().unwrap());
RCDT.into_iter().filter(|r| u < *r).count() as i16
}
/// Computes an integer approximation of 2^63 * ccs * exp(-x).
fn approx_exp(x: f64, ccs: f64) -> u64 {
// The constants C are used to approximate exp(-x); these
// constants are taken from FACCT (up to a scaling factor
// of 2^63):
// https://eprint.iacr.org/2018/1234
// https://github.com/raykzhao/gaussian
const C: [u64; 13] = [
0x00000004741183a3u64,
0x00000036548cfc06u64,
0x0000024fdcbf140au64,
0x0000171d939de045u64,
0x0000d00cf58f6f84u64,
0x000680681cf796e3u64,
0x002d82d8305b0feau64,
0x011111110e066fd0u64,
0x0555555555070f00u64,
0x155555555581ff00u64,
0x400000000002b400u64,
0x7fffffffffff4800u64,
0x8000000000000000u64,
];
let mut z: u64;
let mut y: u64;
let twoe63 = 1u64 << 63;
y = C[0];
z = f64::floor(x * (twoe63 as f64)) as u64;
for cu in C.iter().skip(1) {
let zy = (z as u128) * (y as u128);
y = cu - ((zy >> 63) as u64);
}
z = f64::floor((twoe63 as f64) * ccs) as u64;
(((z as u128) * (y as u128)) >> 63) as u64
}
/// A random bool that is true with probability ≈ ccs · exp(-x).
fn ber_exp(x: f64, ccs: f64, random_bytes: [u8; 7]) -> bool {
// 0.69314718055994530941 = ln(2)
let s = f64::floor(x / LN_2) as usize;
let r = x - LN_2 * (s as f64);
let shamt = usize::min(s, 63);
let z = ((((approx_exp(r, ccs) as u128) << 1) - 1) >> shamt) as u64;
let mut w = 0i16;
for (index, i) in (0..64).step_by(8).rev().enumerate() {
let byte = random_bytes[index];
w = (byte as i16) - (((z >> i) & 0xff) as i16);
if w != 0 {
break;
}
}
w < 0
}
/// Samples an integer from the Gaussian distribution with given mean (mu) and standard deviation
/// (sigma).
pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R) -> i16 {
const SIGMA_MAX: f64 = 1.8205;
const INV_2SIGMA_MAX_SQ: f64 = 1f64 / (2f64 * SIGMA_MAX * SIGMA_MAX);
let isigma = 1f64 / sigma;
let dss = 0.5f64 * isigma * isigma;
let s = f64::floor(mu);
let r = mu - s;
let ccs = sigma_min * isigma;
loop {
let z0 = base_sampler(rng.gen());
let random_byte: u8 = rng.gen();
let b = (random_byte & 1) as i16;
let z = b + ((b << 1) - 1) * z0;
let zf_min_r = (z as f64) - r;
// x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
let x = zf_min_r * zf_min_r * dss - (z0 * z0) as f64 * INV_2SIGMA_MAX_SQ;
if ber_exp(x, ccs, rng.gen()) {
return z + (s as i16);
}
}
}
#[cfg(all(test, feature = "std"))]
mod test {
use alloc::vec::Vec;
use std::{thread::sleep, time::Duration};
use rand::RngCore;
use super::{approx_exp, ber_exp, sampler_z};
/// RNG used only for testing purposes, whereby the produced
/// string of random bytes is equal to the one it is initialized
/// with. Whatever you do, do not use this RNG in production.
struct UnsafeBufferRng {
buffer: Vec<u8>,
index: usize,
}
impl UnsafeBufferRng {
fn new(buffer: &[u8]) -> Self {
Self { buffer: buffer.to_vec(), index: 0 }
}
fn next(&mut self) -> u8 {
if self.buffer.len() <= self.index {
// panic!("Ran out of buffer.");
sleep(Duration::from_millis(10));
0u8
} else {
let return_value = self.buffer[self.index];
self.index += 1;
return_value
}
}
}
impl RngCore for UnsafeBufferRng {
fn next_u32(&mut self) -> u32 {
// let bytes: [u8; 4] = (0..4)
// .map(|_| self.next())
// .collect_vec()
// .try_into()
// .unwrap();
// u32::from_be_bytes(bytes)
u32::from_le_bytes([self.next(), 0, 0, 0])
}
fn next_u64(&mut self) -> u64 {
// let bytes: [u8; 8] = (0..8)
// .map(|_| self.next())
// .collect_vec()
// .try_into()
// .unwrap();
// u64::from_be_bytes(bytes)
u64::from_le_bytes([self.next(), 0, 0, 0, 0, 0, 0, 0])
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
for d in dest.iter_mut() {
*d = self.next();
}
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
for d in dest.iter_mut() {
*d = self.next();
}
Ok(())
}
}
#[test]
fn test_unsafe_buffer_rng() {
let seed_bytes = hex::decode("7FFECD162AE2").unwrap();
let mut rng = UnsafeBufferRng::new(&seed_bytes);
let generated_bytes: Vec<u8> = (0..seed_bytes.len()).map(|_| rng.next()).collect();
assert_eq!(seed_bytes, generated_bytes);
}
#[test]
fn test_approx_exp() {
let precision = 1u64 << 14;
// known answers were generated with the following sage script:
//```sage
// num_samples = 10
// precision = 200
// R = Reals(precision)
//
// print(f"let kats : [(f64, f64, u64);{num_samples}] = [")
// for i in range(num_samples):
// x = RDF.random_element(0.0, 0.693147180559945)
// ccs = RDF.random_element(0.0, 1.0)
// res = round(2^63 * R(ccs) * exp(R(-x)))
// print(f"({x}, {ccs}, {res}),")
// print("];")
// ```
let kats: [(f64, f64, u64); 10] = [
(0.2314993926072656, 0.8148006314615972, 5962140072160879737),
(0.2648875572812225, 0.12769669655309035, 903712282351034505),
(0.11251957513682391, 0.9264611470305881, 7635725498677341553),
(0.04353439307256617, 0.5306497137523327, 4685877322232397936),
(0.41834495299784347, 0.879438856118578, 5338392138535350986),
(0.32579398973228557, 0.16513412873289002, 1099603299296456803),
(0.5939508073919817, 0.029776019144967303, 151637565622779016),
(0.2932367999399056, 0.37123847662857923, 2553827649386670452),
(0.5005699297417507, 0.31447208863888976, 1758235618083658825),
(0.4876437338498085, 0.6159515298936868, 3488632981903743976),
];
for (x, ccs, answer) in kats {
let difference = (answer as i128) - (approx_exp(x, ccs) as i128);
assert!(
(difference * difference) as u64 <= precision * precision,
"answer: {answer} versus approximation: {}\ndifference: {} whereas precision: {}",
approx_exp(x, ccs),
difference,
precision
);
}
}
#[test]
fn test_ber_exp() {
let kats = [
(
1.268_314_048_020_498_4,
0.749_990_853_267_664_9,
hex::decode("ea000000000000").unwrap(),
false,
),
(
0.001_563_917_959_143_409_6,
0.749_990_853_267_664_9,
hex::decode("6c000000000000").unwrap(),
true,
),
(
0.017_921_215_753_999_235,
0.749_990_853_267_664_9,
hex::decode("c2000000000000").unwrap(),
false,
),
(
0.776_117_648_844_980_6,
0.751_181_554_542_520_8,
hex::decode("58000000000000").unwrap(),
true,
),
];
for (x, ccs, bytes, answer) in kats {
assert_eq!(answer, ber_exp(x, ccs, bytes.try_into().unwrap()));
}
}
#[test]
fn test_sampler_z() {
let sigma_min = 1.277833697;
// known answers from the doc, table 3.2, page 44
// https://falcon-sign.info/falcon.pdf
// The zeros were added to account for dropped bytes.
let kats = [
(-91.90471153063714,1.7037990414754918,hex::decode("0fc5442ff043d66e91d1ea000000000000cac64ea5450a22941edc6c").unwrap(),-92),
(-8.322564895434937,1.7037990414754918,hex::decode("f4da0f8d8444d1a77265c2000000000000ef6f98bbbb4bee7db8d9b3").unwrap(),-8),
(-19.096516109216804,1.7035823083824078,hex::decode("db47f6d7fb9b19f25c36d6000000000000b9334d477a8bc0be68145d").unwrap(),-20),
(-11.335543982423326, 1.7035823083824078, hex::decode("ae41b4f5209665c74d00dc000000000000c1a8168a7bb516b3190cb42c1ded26cd52000000000000aed770eca7dd334e0547bcc3c163ce0b").unwrap(), -12),
(7.9386734193997555, 1.6984647769450156, hex::decode("31054166c1012780c603ae0000000000009b833cec73f2f41ca5807c000000000000c89c92158834632f9b1555").unwrap(), 8),
(-28.990850086867255, 1.6984647769450156, hex::decode("737e9d68a50a06dbbc6477").unwrap(), -30),
(-9.071257914091655, 1.6980782114808988, hex::decode("a98ddd14bf0bf22061d632").unwrap(), -10),
(-43.88754568839566, 1.6980782114808988, hex::decode("3cbf6818a68f7ab9991514").unwrap(), -41),
(-58.17435547946095,1.7010983419195522,hex::decode("6f8633f5bfa5d26848668e0000000000003d5ddd46958e97630410587c").unwrap(),-61),
(-43.58664906684732, 1.7010983419195522, hex::decode("272bc6c25f5c5ee53f83c40000000000003a361fbc7cc91dc783e20a").unwrap(), -46),
(-34.70565203313315, 1.7009387219711465, hex::decode("45443c59574c2c3b07e2e1000000000000d9071e6d133dbe32754b0a").unwrap(), -34),
(-44.36009577368896, 1.7009387219711465, hex::decode("6ac116ed60c258e2cbaeab000000000000728c4823e6da36e18d08da0000000000005d0cc104e21cc7fd1f5ca8000000000000d9dbb675266c928448059e").unwrap(), -44),
(-21.783037079346236, 1.6958406126012802, hex::decode("68163bc1e2cbf3e18e7426").unwrap(), -23),
(-39.68827784633828, 1.6958406126012802, hex::decode("d6a1b51d76222a705a0259").unwrap(), -40),
(-18.488607061056847, 1.6955259305261838, hex::decode("f0523bfaa8a394bf4ea5c10000000000000f842366fde286d6a30803").unwrap(), -22),
(-48.39610939101591, 1.6955259305261838, hex::decode("87bd87e63374cee62127fc0000000000006931104aab64f136a0485b").unwrap(), -50),
];
for (mu, sigma, random_bytes, answer) in kats {
assert_eq!(
sampler_z(mu, sigma, sigma_min, &mut UnsafeBufferRng::new(&random_bytes)),
answer
);
}
}
}

View file

@ -1,105 +0,0 @@
use crate::{
hash::rpo::Rpo256,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
Felt, Word, ZERO,
};
mod hash_to_point;
mod keys;
mod math;
mod signature;
pub use self::{
keys::{PubKeyPoly, PublicKey, SecretKey},
math::Polynomial,
signature::{Signature, SignatureHeader, SignaturePoly},
};
// CONSTANTS
// ================================================================================================
// The Falcon modulus p.
const MODULUS: i16 = 12289;
// Number of bits needed to encode an element in the Falcon field.
const FALCON_ENCODING_BITS: u32 = 14;
// The Falcon parameters for Falcon-512. This is the degree of the polynomial `phi := x^N + 1`
// defining the ring Z_p[x]/(phi).
const N: usize = 512;
const LOG_N: u8 = 9;
/// Length of nonce used for key-pair generation.
const SIG_NONCE_LEN: usize = 40;
/// Number of filed elements used to encode a nonce.
const NONCE_ELEMENTS: usize = 8;
/// Public key length as a u8 vector.
pub const PK_LEN: usize = 897;
/// Secret key length as a u8 vector.
pub const SK_LEN: usize = 1281;
/// Signature length as a u8 vector.
const SIG_POLY_BYTE_LEN: usize = 625;
/// Bound on the squared-norm of the signature.
const SIG_L2_BOUND: u64 = 34034726;
/// Standard deviation of the Gaussian over the lattice.
const SIGMA: f64 = 165.7366171829776;
// TYPE ALIASES
// ================================================================================================
type ShortLatticeBasis = [Polynomial<i16>; 4];
// NONCE
// ================================================================================================
/// Nonce of the Falcon signature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Nonce([u8; SIG_NONCE_LEN]);
impl Nonce {
/// Returns a new [Nonce] instantiated from the provided bytes.
pub fn new(bytes: [u8; SIG_NONCE_LEN]) -> Self {
Self(bytes)
}
/// Returns the underlying bytes of this nonce.
pub fn as_bytes(&self) -> &[u8; SIG_NONCE_LEN] {
&self.0
}
/// Converts byte representation of the nonce into field element representation.
///
/// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
/// of the nonce and interpreting them as field elements.
pub fn to_elements(&self) -> [Felt; NONCE_ELEMENTS] {
let mut buffer = [0_u8; 8];
let mut result = [ZERO; 8];
for (i, bytes) in self.0.chunks(5).enumerate() {
buffer[..5].copy_from_slice(bytes);
// we can safely (without overflow) create a new Felt from u64 value here since this
// value contains at most 5 bytes
result[i] = Felt::new(u64::from_le_bytes(buffer));
}
result
}
}
impl Serializable for &Nonce {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.0)
}
}
impl Deserializable for Nonce {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let bytes = source.read()?;
Ok(Self(bytes))
}
}

View file

@ -1,375 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use core::ops::Deref;
use num::Zero;
use super::{
hash_to_point::hash_to_point_rpo256,
keys::PubKeyPoly,
math::{FalconFelt, FastFft, Polynomial},
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
};
// FALCON SIGNATURE
// ================================================================================================
/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2 a nonce `r`, and a public
/// key polynomial `h` where:
/// - p := 12289
/// - phi := x^512 + 1
///
/// The signature verifies against a public key `pk` if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
/// where |.| is the norm and:
/// - c = HashToPoint(r || message)
/// - pk = Rpo256::hash(h)
///
/// Here h is a polynomial representing the public key and pk is its digest using the Rpo256 hash
/// function. c is a polynomial that is the hash-to-point of the message being signed.
///
/// The polynomial h is serialized as:
/// 1. 1 byte representing the log2(512) i.e., 9.
/// 2. 896 bytes for the public key itself.
///
/// The signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
/// together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header
/// byte is set to `10111001` which differentiates it from the standardized instantiation of the
/// Falcon signature.
/// 2. 40 bytes for the nonce.
/// 4. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
header: SignatureHeader,
nonce: Nonce,
s2: SignaturePoly,
h: PubKeyPoly,
}
impl Signature {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
pub fn new(nonce: Nonce, h: PubKeyPoly, s2: SignaturePoly) -> Signature {
Self {
header: SignatureHeader::default(),
nonce,
s2,
h,
}
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the public key polynomial h.
pub fn pk_poly(&self) -> &PubKeyPoly {
&self.h
}
// Returns the polynomial representation of the signature in Z_p[x]/(phi).
pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
&self.s2
}
/// Returns the nonce component of the signature.
pub fn nonce(&self) -> &Nonce {
&self.nonce
}
// SIGNATURE VERIFICATION
// --------------------------------------------------------------------------------------------
/// Returns true if this signature is a valid signature for the specified message generated
/// against the secret key matching the specified public key commitment.
pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
// compute the hash of the public key polynomial
let h_felt: Polynomial<Felt> = (&**self.pk_poly()).into();
let h_digest: Word = Rpo256::hash_elements(&h_felt.coefficients).into();
if h_digest != pubkey_com {
return false;
}
let c = hash_to_point_rpo256(message, &self.nonce);
verify_helper(&c, &self.s2, self.pk_poly())
}
}
impl Serializable for Signature {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(&self.header);
target.write(&self.nonce);
target.write(&self.s2);
target.write(&self.h);
}
}
impl Deserializable for Signature {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read()?;
let nonce = source.read()?;
let s2 = source.read()?;
let h = source.read()?;
Ok(Self { header, nonce, s2, h })
}
}
// SIGNATURE HEADER
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignatureHeader(u8);
impl Default for SignatureHeader {
/// According to section 3.11.3 in the specification [1], the signature header has the format
/// `0cc1nnnn` where:
///
/// 1. `cc` signifies the encoding method. `01` denotes using the compression encoding method
/// and `10` denotes encoding using the uncompressed method.
/// 2. `nnnn` encodes `LOG_N`.
///
/// For RPO Falcon 512 we use compression encoding and N = 512. Moreover, to differentiate the
/// RPO Falcon variant from the reference variant using SHAKE256, we flip the first bit in the
/// header. Thus, for RPO Falcon 512 the header is `10111001`
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn default() -> Self {
Self(0b1011_1001)
}
}
impl Serializable for &SignatureHeader {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(self.0)
}
}
impl Deserializable for SignatureHeader {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read_u8()?;
let (encoding, log_n) = (header >> 4, header & 0b00001111);
if encoding != 0b1011 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: not supported encoding algorithm".to_string(),
));
}
if log_n != LOG_N {
return Err(DeserializationError::InvalidValue(
format!("Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided")
));
}
Ok(Self(header))
}
}
// SIGNATURE POLYNOMIAL
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignaturePoly(pub Polynomial<FalconFelt>);
impl Deref for SignaturePoly {
type Target = Polynomial<FalconFelt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Polynomial<FalconFelt>> for SignaturePoly {
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
Self(pk_poly)
}
}
impl TryFrom<&[i16; N]> for SignaturePoly {
type Error = ();
fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
if are_coefficients_valid(coefficients) {
Ok(Self(coefficients.to_vec().into()))
} else {
Err(())
}
}
}
impl Serializable for &SignaturePoly {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let sig_coeff: Vec<i16> = self.0.coefficients.iter().map(|a| a.balanced_value()).collect();
let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];
let mut acc = 0;
let mut acc_len = 0;
let mut v = 0;
let mut t;
let mut w;
// For each coefficient of x:
// - the sign is encoded on 1 bit
// - the 7 lower bits are encoded naively (binary)
// - the high bits are encoded in unary encoding
//
// Algorithm 17 p. 47 of the specification [1].
//
// [1]: https://falcon-sign.info/falcon.pdf
for &c in sig_coeff.iter() {
acc <<= 1;
t = c;
if t < 0 {
t = -t;
acc |= 1;
}
w = t as u16;
acc <<= 7;
let mask = 127_u32;
acc |= (w as u32) & mask;
w >>= 7;
acc_len += 8;
acc <<= w + 1;
acc |= 1;
acc_len += w + 1;
while acc_len >= 8 {
acc_len -= 8;
sk_bytes[v] = (acc >> acc_len) as u8;
v += 1;
}
}
if acc_len > 0 {
sk_bytes[v] = (acc << (8 - acc_len)) as u8;
}
target.write_bytes(&sk_bytes);
}
}
impl Deserializable for SignaturePoly {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;
let mut input_idx = 0;
let mut acc = 0u32;
let mut acc_len = 0;
let mut coefficients = [FalconFelt::zero(); N];
// Algorithm 18 p. 48 of the specification [1].
//
// [1]: https://falcon-sign.info/falcon.pdf
for c in coefficients.iter_mut() {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
let b = acc >> acc_len;
let s = b & 128;
let mut m = b & 127;
loop {
if acc_len == 0 {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
acc_len = 8;
}
acc_len -= 1;
if ((acc >> acc_len) & 1) != 0 {
break;
}
m += 128;
if m >= 2048 {
return Err(DeserializationError::InvalidValue(format!(
"Failed to decode signature: high bits {m} exceed 2048",
)));
}
}
if s != 0 && m == 0 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: -0 is forbidden".to_string(),
));
}
let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
*c = FalconFelt::new(felt as i16);
}
if (acc & ((1 << acc_len) - 1)) != 0 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: Non-zero unused bits in the last byte".to_string(),
));
}
Ok(Polynomial::new(coefficients.to_vec()).into())
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Takes the hash-to-point polynomial `c` of a message, the signature polynomial over
/// the message `s2` and a public key polynomial and returns `true` is the signature is a valid
/// signature for the given parameters, otherwise it returns `false`.
fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PubKeyPoly) -> bool {
let h_fft = h.fft();
let s2_fft = s2.fft();
let c_fft = c.fft();
// compute the signature polynomial s1 using s1 = c - s2 * h
let s1_fft = c_fft - s2_fft.hadamard_mul(&h_fft);
let s1 = s1_fft.ifft();
// compute the norm squared of (s1, s2)
let length_squared_s1 = s1.norm_squared();
let length_squared_s2 = s2.norm_squared();
let length_squared = length_squared_s1 + length_squared_s2;
length_squared < SIG_L2_BOUND
}
/// Checks whether a set of coefficients is a valid one for a signature polynomial.
fn are_coefficients_valid(x: &[i16]) -> bool {
if x.len() != N {
return false;
}
for &c in x {
if !(-2047..=2047).contains(&c) {
return false;
}
}
true
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use super::{super::SecretKey, *};
#[test]
fn test_serialization_round_trip() {
let seed = [0_u8; 32];
let mut rng = ChaCha20Rng::from_seed(seed);
let sk = SecretKey::with_rng(&mut rng);
let signature = sk.sign_with_rng(Word::default(), &mut rng);
let serialized = signature.to_bytes();
let deserialized = Signature::read_from_bytes(&serialized).unwrap();
assert_eq!(signature.sig_poly(), deserialized.sig_poly());
}
}

View file

@ -1,383 +0,0 @@
use alloc::{string::String, vec::Vec};
use core::{
mem::{size_of, transmute, transmute_copy},
ops::Deref,
slice::{self, from_raw_parts},
};
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================
const DIGEST32_BYTES: usize = 32;
const DIGEST24_BYTES: usize = 24;
const DIGEST20_BYTES: usize = 20;
// BLAKE3 N-BIT OUTPUT
// ================================================================================================
/// N-bytes output of a blake3 function.
///
/// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32
/// bytes.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]);
impl<const N: usize> Blake3Digest<N> {
pub fn digests_as_bytes(digests: &[Blake3Digest<N>]) -> &[u8] {
let p = digests.as_ptr();
let len = digests.len() * N;
unsafe { slice::from_raw_parts(p as *const u8, len) }
}
}
impl<const N: usize> Default for Blake3Digest<N> {
fn default() -> Self {
Self([0; N])
}
}
impl<const N: usize> Deref for Blake3Digest<N> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<const N: usize> From<Blake3Digest<N>> for [u8; N] {
fn from(value: Blake3Digest<N>) -> Self {
value.0
}
}
impl<const N: usize> From<[u8; N]> for Blake3Digest<N> {
fn from(value: [u8; N]) -> Self {
Self(value)
}
}
impl<const N: usize> From<Blake3Digest<N>> for String {
fn from(value: Blake3Digest<N>) -> Self {
bytes_to_hex_string(value.as_bytes())
}
}
impl<const N: usize> TryFrom<&str> for Blake3Digest<N> {
type Error = HexParseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).map(|v| v.into())
}
}
impl<const N: usize> Serializable for Blake3Digest<N> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.0);
}
}
impl<const N: usize> Deserializable for Blake3Digest<N> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
source.read_array().map(Self)
}
}
impl<const N: usize> Digest for Blake3Digest<N> {
fn as_bytes(&self) -> [u8; 32] {
// compile-time assertion
assert!(N <= 32, "digest currently supports only 32 bytes!");
expand_bytes(&self.0)
}
}
// BLAKE3 256-BIT OUTPUT
// ================================================================================================
/// 256-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_256;
impl Hasher for Blake3_256 {
/// Blake3 collision resistance is 128-bits for 32-bytes output.
const COLLISION_RESISTANCE: u32 = 128;
type Digest = Blake3Digest<32>;
fn hash(bytes: &[u8]) -> Self::Digest {
Blake3Digest(blake3::hash(bytes).into())
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
Self::hash(prepare_merge(values))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Blake3Digest(blake3::hash(Blake3Digest::digests_as_bytes(values)).into())
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0);
hasher.update(&value.to_le_bytes());
Blake3Digest(hasher.finalize().into())
}
}
impl ElementHasher for Blake3_256 {
type BaseField = Felt;
fn hash_elements<E>(elements: &[E]) -> Self::Digest
where
E: FieldElement<BaseField = Self::BaseField>,
{
Blake3Digest(hash_elements(elements))
}
}
impl Blake3_256 {
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST32_BYTES> {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[Blake3Digest<DIGEST32_BYTES>; 2]) -> Blake3Digest<DIGEST32_BYTES> {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST32_BYTES>
where
E: FieldElement<BaseField = Felt>,
{
<Self as ElementHasher>::hash_elements(elements)
}
}
// BLAKE3 192-BIT OUTPUT
// ================================================================================================
/// 192-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_192;
impl Hasher for Blake3_192 {
/// Blake3 collision resistance is 96-bits for 24-bytes output.
const COLLISION_RESISTANCE: u32 = 96;
type Digest = Blake3Digest<24>;
fn hash(bytes: &[u8]) -> Self::Digest {
Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
Self::hash(prepare_merge(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0);
hasher.update(&value.to_le_bytes());
Blake3Digest(*shrink_bytes(&hasher.finalize().into()))
}
}
impl ElementHasher for Blake3_192 {
type BaseField = Felt;
fn hash_elements<E>(elements: &[E]) -> Self::Digest
where
E: FieldElement<BaseField = Self::BaseField>,
{
Blake3Digest(hash_elements(elements))
}
}
impl Blake3_192 {
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST24_BYTES> {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[Blake3Digest<DIGEST24_BYTES>; 2]) -> Blake3Digest<DIGEST24_BYTES> {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST24_BYTES>
where
E: FieldElement<BaseField = Felt>,
{
<Self as ElementHasher>::hash_elements(elements)
}
}
// BLAKE3 160-BIT OUTPUT
// ================================================================================================
/// 160-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_160;
impl Hasher for Blake3_160 {
/// Blake3 collision resistance is 80-bits for 20-bytes output.
const COLLISION_RESISTANCE: u32 = 80;
type Digest = Blake3Digest<20>;
fn hash(bytes: &[u8]) -> Self::Digest {
Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
Self::hash(prepare_merge(values))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0);
hasher.update(&value.to_le_bytes());
Blake3Digest(*shrink_bytes(&hasher.finalize().into()))
}
}
impl ElementHasher for Blake3_160 {
type BaseField = Felt;
fn hash_elements<E>(elements: &[E]) -> Self::Digest
where
E: FieldElement<BaseField = Self::BaseField>,
{
Blake3Digest(hash_elements(elements))
}
}
impl Blake3_160 {
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST20_BYTES> {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[Blake3Digest<DIGEST20_BYTES>; 2]) -> Blake3Digest<DIGEST20_BYTES> {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST20_BYTES>
where
E: FieldElement<BaseField = Felt>,
{
<Self as ElementHasher>::hash_elements(elements)
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Zero-copy ref shrink to array.
fn shrink_bytes<const M: usize, const N: usize>(bytes: &[u8; M]) -> &[u8; N] {
// compile-time assertion
assert!(M >= N, "N should fit in M so it can be safely transmuted into a smaller slice!");
// safety: bytes len is asserted
unsafe { transmute(bytes) }
}
/// Hash the elements into bytes and shrink the output.
fn hash_elements<const N: usize, E>(elements: &[E]) -> [u8; N]
where
E: FieldElement<BaseField = Felt>,
{
// don't leak assumptions from felt and check its actual implementation.
// this is a compile-time branch so it is for free
let digest = if Felt::IS_CANONICAL {
blake3::hash(E::elements_as_bytes(elements))
} else {
let mut hasher = blake3::Hasher::new();
// BLAKE3 state is 64 bytes - so, we can absorb 64 bytes into the state in a single
// permutation. we move the elements into the hasher via the buffer to give the CPU
// a chance to process multiple element-to-byte conversions in parallel
let mut buf = [0_u8; 64];
let mut chunk_iter = E::slice_as_base_elements(elements).chunks_exact(8);
for chunk in chunk_iter.by_ref() {
for i in 0..8 {
buf[i * 8..(i + 1) * 8].copy_from_slice(&chunk[i].as_int().to_le_bytes());
}
hasher.update(&buf);
}
for element in chunk_iter.remainder() {
hasher.update(&element.as_int().to_le_bytes());
}
hasher.finalize()
};
*shrink_bytes(&digest.into())
}
/// Owned bytes expansion.
fn expand_bytes<const M: usize, const N: usize>(bytes: &[u8; M]) -> [u8; N] {
// compile-time assertion
assert!(M <= N, "M should fit in N so M can be expanded!");
// this branch is constant so it will be optimized to be either one of the variants in release
// mode
if M == N {
// safety: the sizes are checked to be the same
unsafe { transmute_copy(bytes) }
} else {
let mut expanded = [0u8; N];
expanded[..M].copy_from_slice(bytes);
expanded
}
}
// Cast the slice into contiguous bytes.
fn prepare_merge<const N: usize, D>(args: &[D; N]) -> &[u8]
where
D: Deref<Target = [u8]>,
{
// compile-time assertion
assert!(N > 0, "N shouldn't represent an empty slice!");
let values = args.as_ptr() as *const u8;
let len = size_of::<D>() * N;
// safety: the values are tested to be contiguous
let bytes = unsafe { from_raw_parts(values, len) };
debug_assert_eq!(args[0].deref(), &bytes[..len / N]);
bytes
}

View file

@ -1,49 +0,0 @@
use alloc::vec::Vec;
use proptest::prelude::*;
use rand_utils::rand_vector;
use super::*;
#[test]
fn blake3_hash_elements() {
// test multiple of 8
let elements = rand_vector::<Felt>(16);
let expected = compute_expected_element_hash(&elements);
let actual: [u8; 32] = hash_elements(&elements);
assert_eq!(&expected, &actual);
// test not multiple of 8
let elements = rand_vector::<Felt>(17);
let expected = compute_expected_element_hash(&elements);
let actual: [u8; 32] = hash_elements(&elements);
assert_eq!(&expected, &actual);
}
proptest! {
#[test]
fn blake160_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
Blake3_160::hash(vec);
}
#[test]
fn blake192_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
Blake3_192::hash(vec);
}
#[test]
fn blake256_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
Blake3_256::hash(vec);
}
}
// HELPER FUNCTIONS
// ================================================================================================
fn compute_expected_element_hash(elements: &[Felt]) -> blake3::Hash {
let mut bytes = Vec::new();
for element in elements.iter() {
bytes.extend_from_slice(&element.as_int().to_le_bytes());
}
blake3::hash(&bytes)
}

View file

@ -1,19 +0,0 @@
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.
use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};
pub mod blake;
mod rescue;
pub mod rpo {
pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
}
pub mod rpx {
pub use super::rescue::{Rpx256, RpxDigest, RpxDigestError};
}
// RE-EXPORTS
// ================================================================================================
pub use winter_crypto::{Digest, ElementHasher, Hasher};

View file

@ -1,101 +0,0 @@
#[cfg(target_feature = "sve")]
pub mod optimized {
use crate::{hash::rescue::STATE_WIDTH, Felt};
mod ffi {
#[link(name = "rpo_sve", kind = "static")]
extern "C" {
pub fn add_constants_and_apply_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
pub fn add_constants_and_apply_inv_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
}
}
#[inline(always)]
pub fn add_constants_and_apply_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
ffi::add_constants_and_apply_sbox(
state.as_mut_ptr() as *mut u64,
ark.as_ptr() as *const u64,
)
}
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
ffi::add_constants_and_apply_inv_sbox(
state.as_mut_ptr() as *mut u64,
ark.as_ptr() as *const u64,
)
}
}
}
#[cfg(target_feature = "avx2")]
mod x86_64_avx2;
#[cfg(target_feature = "avx2")]
pub mod optimized {
use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
use crate::{
hash::rescue::{add_constants, STATE_WIDTH},
Felt,
};
#[inline(always)]
pub fn add_constants_and_apply_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
add_constants(state, ark);
unsafe {
apply_sbox(std::mem::transmute(state));
}
true
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
add_constants(state, ark);
unsafe {
apply_inv_sbox(std::mem::transmute(state));
}
true
}
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
pub mod optimized {
use crate::{hash::rescue::STATE_WIDTH, Felt};
#[inline(always)]
pub fn add_constants_and_apply_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
}

View file

@ -1,328 +0,0 @@
use core::arch::x86_64::*;
// The following AVX2 implementation has been copied from plonky2:
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily emulated.
// The method recognizes that for a + b overflowed iff (a + b) < a:
// 1. res_lo = a_lo + b_lo
// 2. carry_mask = res_lo < a_lo
// 3. res_hi = a_hi + b_hi - carry_mask
//
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts 1
// << 63 to enable this trick. Addition with carry example:
// 1. a_lo_s = shift(a_lo)
// 2. res_lo_s = a_lo_s + b_lo
// 3. carry_mask = res_lo_s <s a_lo_s
// 4. res_lo = shift(res_lo_s)
// 5. res_hi = a_hi + b_hi - carry_mask
//
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition
// is shifted if exactly one of the operands is shifted, as is the case on
// line 2. Line 3. performs a signed comparison res_lo_s <s a_lo_s on shifted values to
// emulate unsigned comparison res_lo <u a_lo on unshifted values. Finally, line 4. reverses the
// shift so the result can be returned.
//
// When performing a chain of calculations, we can often save instructions by letting
// the shift propagate through and only undoing it when necessary.
// For example, to compute the addition of three two-word (128-bit) numbers we can do:
// 1. a_lo_s = shift(a_lo)
// 2. tmp_lo_s = a_lo_s + b_lo
// 3. tmp_carry_mask = tmp_lo_s <s a_lo_s
// 4. tmp_hi = a_hi + b_hi - tmp_carry_mask
// 5. res_lo_s = tmp_lo_s + c_lo vi. res_carry_mask = res_lo_s <s tmp_lo_s
// 6. res_carry_mask = res_lo_s <s tmp_lo_s
// 7. res_lo = shift(res_lo_s)
// 8. res_hi = tmp_hi + c_hi - res_carry_mask
//
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
#[inline(always)]
pub fn branch_hint() {
// NOTE: These are the currently supported assembly architectures. See the
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
// the most up-to-date list.
#[cfg(any(
target_arch = "aarch64",
target_arch = "arm",
target_arch = "riscv32",
target_arch = "riscv64",
target_arch = "x86",
target_arch = "x86_64",
))]
unsafe {
core::arch::asm!("", options(nomem, nostack, preserves_flags));
}
}
macro_rules! map3 {
($f:ident:: < $l:literal > , $v:ident) => {
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
};
($f:ident:: < $l:literal > , $v1:ident, $v2:ident) => {
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
};
($f:ident, $v:ident) => {
($f($v.0), $f($v.1), $f($v.2))
};
($f:ident, $v0:ident, $v1:ident) => {
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
};
($f:ident,rep $v0:ident, $v1:ident) => {
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
};
($f:ident, $v0:ident,rep $v1:ident) => {
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
};
}
#[inline(always)]
unsafe fn square3(
x: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
let x_hi = {
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
// This is safe and free.
let x_ps = map3!(_mm256_castsi256_ps, x);
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
map3!(_mm256_castps_si256, x_hi_ps)
};
// All pairwise multiplications.
let mul_ll = map3!(_mm256_mul_epu32, x, x);
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
// position).
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
(res_lo, res_hi)
}
#[inline(always)]
unsafe fn mul3(
x: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
let epsilon = _mm256_set1_epi64x(0xffffffff);
let x_hi = {
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
// This is safe and free.
let x_ps = map3!(_mm256_castsi256_ps, x);
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
map3!(_mm256_castps_si256, x_hi_ps)
};
let y_hi = {
let y_ps = map3!(_mm256_castsi256_ps, y);
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
map3!(_mm256_castps_si256, y_hi_ps)
};
// All four pairwise multiplications
let mul_ll = map3!(_mm256_mul_epu32, x, y);
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
// Bignum addition
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
// Also, extract high 32 bits of t0 and add to mul_hh.
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
// Lastly, extract the high 32 bits of t1 and add to t2.
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
// position).
let t1_lo = {
let t1_ps = map3!(_mm256_castsi256_ps, t1);
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
map3!(_mm256_castps_si256, t1_lo_ps)
};
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
(res_lo, res_hi)
}
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
#[inline(always)]
unsafe fn add_small(
x_s: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
res_s
}
#[inline(always)]
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
// The subtraction is very unlikely to overflow so we're best off branching.
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
// floating-point (this is free).
let mask_pd = _mm256_castsi256_pd(mask);
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
// did not occur for any of the vector elements.
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
res_wrapped_s
} else {
branch_hint();
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
_mm256_sub_epi64(res_wrapped_s, adj_amount)
}
}
/// Addition, where the second operand is much smaller than `0xffffffff00000001`.
#[inline(always)]
unsafe fn sub_tiny(
x_s: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
res_s
}
#[inline(always)]
unsafe fn reduce3(
(lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
) -> (__m256i, __m256i, __m256i) {
let sign_bit = _mm256_set1_epi64x(i64::MIN);
let epsilon = _mm256_set1_epi64x(0xffffffff);
let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
let lo1_s = sub_tiny(lo0_s, hi_hi0);
let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
let lo2_s = add_small(lo1_s, t1);
let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
lo2
}
#[inline(always)]
unsafe fn mul_reduce(
a: (__m256i, __m256i, __m256i),
b: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
reduce3(mul3(a, b))
}
#[inline(always)]
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
reduce3(square3(state))
}
#[inline(always)]
unsafe fn exp_acc(
high: (__m256i, __m256i, __m256i),
low: (__m256i, __m256i, __m256i),
exp: usize,
) -> (__m256i, __m256i, __m256i) {
let mut result = high;
for _ in 0..exp {
result = square_reduce(result);
}
mul_reduce(result, low)
}
#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
let state2 = square_reduce(state);
let state4_unreduced = square3(state2);
let state3_unreduced = mul3(state2, state);
let state4 = reduce3(state4_unreduced);
let state3 = reduce3(state3_unreduced);
let state7_unreduced = mul3(state3, state4);
let state7 = reduce3(state7_unreduced);
state7
}
#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
// compute base^10540996611094048183 using 72 multiplications per array element
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
// compute base^10
let t1 = square_reduce(state);
// compute base^100
let t2 = square_reduce(t1);
// compute base^100100
let t3 = exp_acc(t2, t2, 3);
// compute base^100100100100
let t4 = exp_acc(t3, t3, 6);
// compute base^100100100100100100100100
let t5 = exp_acc(t4, t4, 12);
// compute base^100100100100100100100100100100
let t6 = exp_acc(t5, t3, 6);
// compute base^1001001001001001001001001001000100100100100100100100100100100
let t7 = exp_acc(t6, t6, 31);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
let b = mul_reduce(t1, mul_reduce(t2, state));
mul_reduce(a, b)
}
#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
(
_mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
)
}
#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
_mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
_mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
_mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}
#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
let mut state = avx2_load(&buffer);
state = do_apply_sbox(state);
avx2_store(buffer, state);
}
#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
let mut state = avx2_load(&buffer);
state = do_apply_inv_sbox(state);
avx2_store(buffer, state);
}

View file

@ -1,197 +0,0 @@
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================
//! This module contains helper functions as well as constants used to perform the vector-matrix
//! multiplication step of the Rescue prime permutation. The special form of our MDS matrix
//! i.e. being circular, allows us to reduce the vector-matrix multiplication to a Hadamard product
//! of two vectors in "frequency domain". This follows from the simple fact that every circulant
//! matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
//! The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
//! with explicit expressions. It also avoids, due to the form of our matrix in the frequency
//! domain, divisions by 2 and repeated modular reductions. This is because of our explicit choice
//! of an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is base on Nabaglo's Plonky2
//! implementation.
// Rescue MDS matrix in frequency domain.
//
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
// and application of the final four 3-point FFT in order to get the full 12-point FFT.
// The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4.
// The code to generate the matrix in frequency domain is based on an adaptation of a code, to
// generate MDS matrices efficiently in original domain, that was developed by the Polygon Zero
// team.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];
// We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
#[inline(always)]
pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;
let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
let (u4, u5, u6) = fft4_real([s1, s4, s7, s10]);
let (u8, u9, u10) = fft4_real([s2, s5, s8, s11]);
// This where the multiplication in frequency domain is done. More precisely, and with
// the appropriate permutations in between, the sequence of
// 3-point FFTs --> multiplication by twiddle factors --> Hadamard multiplication -->
// 3 point iFFTs --> multiplication by (inverse) twiddle factors
// is "squashed" into one step composed of the functions "block1", "block2" and "block3".
// The expressions in the aforementioned functions are the result of explicit computations
// combined with the Karatsuba trick for the multiplication of Complex numbers.
let [v0, v4, v8] = block1([u0, u4, u8], MDS_FREQ_BLOCK_ONE);
let [v1, v5, v9] = block2([u1, u5, u9], MDS_FREQ_BLOCK_TWO);
let [v2, v6, v10] = block3([u2, u6, u10], MDS_FREQ_BLOCK_THREE);
// The 4th block is not computed as it is similar to the 2nd one, up to complex conjugation,
// and is, due to the use of the real FFT and iFFT, redundant.
let [s0, s3, s6, s9] = ifft4_real((v0, v1, v2));
let [s1, s4, s7, s10] = ifft4_real((v4, v5, v6));
let [s2, s5, s8, s11] = ifft4_real((v8, v9, v10));
[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
}
// We use the real FFT to avoid redundant computations. See https://www.mdpi.com/2076-3417/12/9/4700
#[inline(always)]
const fn fft2_real(x: [u64; 2]) -> [i64; 2] {
[(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)]
}
#[inline(always)]
const fn ifft2_real(y: [i64; 2]) -> [u64; 2] {
// We avoid divisions by 2 by appropriately scaling the MDS matrix constants.
[(y[0] + y[1]) as u64, (y[0] - y[1]) as u64]
}
#[inline(always)]
const fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
let [z0, z2] = fft2_real([x[0], x[2]]);
let [z1, z3] = fft2_real([x[1], x[3]]);
let y0 = z0 + z1;
let y1 = (z2, -z3);
let y2 = z0 - z1;
(y0, y1, y2)
}
#[inline(always)]
const fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] {
// In calculating 'z0' and 'z1', division by 2 is avoided by appropriately scaling
// the MDS matrix constants.
let z0 = y.0 + y.2;
let z1 = y.0 - y.2;
let z2 = y.1 .0;
let z3 = -y.1 .1;
let [x0, x2] = ifft2_real([z0, z2]);
let [x1, x3] = ifft2_real([z1, z3]);
[x0, x1, x2, x3]
}
#[inline(always)]
const fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 + x1 * y2 + x2 * y1;
let z1 = x0 * y1 + x1 * y0 + x2 * y2;
let z2 = x0 * y2 + x1 * y1 + x2 * y0;
[z0, z1, z2]
}
#[inline(always)]
const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x;
let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y;
let x0s = x0r + x0i;
let x1s = x1r + x1i;
let x2s = x2r + x2i;
let y0s = y0r + y0i;
let y1s = y1r + y1i;
let y2s = y2r + y2i;
// Compute x0y0 ix1y2 ix2y1 using Karatsuba for complex numbers multiplication
let m0 = (x0r * y0r, x0i * y0i);
let m1 = (x1r * y2r, x1i * y2i);
let m2 = (x2r * y1r, x2i * y1i);
let z0r = (m0.0 - m0.1) + (x1s * y2s - m1.0 - m1.1) + (x2s * y1s - m2.0 - m2.1);
let z0i = (x0s * y0s - m0.0 - m0.1) + (-m1.0 + m1.1) + (-m2.0 + m2.1);
let z0 = (z0r, z0i);
// Compute x0y1 + x1y0 ix2y2 using Karatsuba for complex numbers multiplication
let m0 = (x0r * y1r, x0i * y1i);
let m1 = (x1r * y0r, x1i * y0i);
let m2 = (x2r * y2r, x2i * y2i);
let z1r = (m0.0 - m0.1) + (m1.0 - m1.1) + (x2s * y2s - m2.0 - m2.1);
let z1i = (x0s * y1s - m0.0 - m0.1) + (x1s * y0s - m1.0 - m1.1) + (-m2.0 + m2.1);
let z1 = (z1r, z1i);
// Compute x0y2 + x1y1 + x2y0 using Karatsuba for complex numbers multiplication
let m0 = (x0r * y2r, x0i * y2i);
let m1 = (x1r * y1r, x1i * y1i);
let m2 = (x2r * y0r, x2i * y0i);
let z2r = (m0.0 - m0.1) + (m1.0 - m1.1) + (m2.0 - m2.1);
let z2i = (x0s * y2s - m0.0 - m0.1) + (x1s * y1s - m1.0 - m1.1) + (x2s * y0s - m2.0 - m2.1);
let z2 = (z2r, z2i);
[z0, z1, z2]
}
#[inline(always)]
const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
let [x0, x1, x2] = x;
let [y0, y1, y2] = y;
let z0 = x0 * y0 - x1 * y2 - x2 * y1;
let z1 = x0 * y1 + x1 * y0 - x2 * y2;
let z2 = x0 * y2 + x1 * y1 + x2 * y0;
[z0, z1, z2]
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use super::super::{apply_mds, Felt, MDS, ZERO};
const STATE_WIDTH: usize = 12;
#[inline(always)]
fn apply_mds_naive(state: &mut [Felt; STATE_WIDTH]) {
let mut result = [ZERO; STATE_WIDTH];
result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
state.iter().zip(mds_row).for_each(|(&s, m)| {
*r += m * s;
});
});
*state = result;
}
proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {
let mut v1 = [ZERO; STATE_WIDTH];
let mut v2;
for i in 0..STATE_WIDTH {
v1[i] = Felt::new(a[i]);
}
v2 = v1;
apply_mds_naive(&mut v1);
apply_mds(&mut v2);
prop_assert_eq!(v1, v2);
}
}
}

View file

@ -1,214 +0,0 @@
use super::{Felt, STATE_WIDTH, ZERO};
mod freq;
pub use freq::mds_multiply_freq;
// MDS MULTIPLICATION
// ================================================================================================
#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
let mut result = [ZERO; STATE_WIDTH];
// Using the linearity of the operations we can split the state into a low||high decomposition
// and operate on each with no overflow and then combine/reduce the result to a field element.
// The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in
// frequency domain.
let mut state_l = [0u64; STATE_WIDTH];
let mut state_h = [0u64; STATE_WIDTH];
for r in 0..STATE_WIDTH {
let s = state[r].inner();
state_h[r] = s >> 32;
state_l[r] = (s as u32) as u64;
}
let state_h = mds_multiply_freq(state_h);
let state_l = mds_multiply_freq(state_l);
for r in 0..STATE_WIDTH {
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
let s_hi = (s >> 64) as u64;
let s_lo = s as u64;
let z = (s_hi << 32) - s_hi;
let (res, over) = s_lo.overflowing_add(z);
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
}
*state = result;
}
// MDS MATRIX
// ================================================================================================
/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
[
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
],
[
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
],
[
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
],
[
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
],
[
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
],
[
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
],
[
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
],
[
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
],
[
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
],
[
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
],
];

View file

@ -1,347 +0,0 @@
use core::ops::Range;
use super::{CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ZERO};
mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};
mod mds;
use mds::{apply_mds, MDS};
mod rpo;
pub use rpo::{Rpo256, RpoDigest, RpoDigestError};
mod rpx;
pub use rpx::{Rpx256, RpxDigest, RpxDigestError};
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================
/// The number of rounds is set to 7. For the RPO hash functions all rounds are uniform. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;
/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
/// The number of bytes needed to encoded a digest
const DIGEST_BYTES: usize = 32;
/// The number of byte chunks defining a field element when hashing a sequence of bytes
const BINARY_CHUNK_SIZE: usize = 7;
/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;
// SBOX FUNCTION
// ================================================================================================
#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
state[0] = state[0].exp7();
state[1] = state[1].exp7();
state[2] = state[2].exp7();
state[3] = state[3].exp7();
state[4] = state[4].exp7();
state[5] = state[5].exp7();
state[6] = state[6].exp7();
state[7] = state[7].exp7();
state[8] = state[8].exp7();
state[9] = state[9].exp7();
state[10] = state[10].exp7();
state[11] = state[11].exp7();
}
// INVERSE SBOX FUNCTION
// ================================================================================================
#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
// compute base^10540996611094048183 using 72 multiplications per array element
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
// compute base^10
let mut t1 = *state;
t1.iter_mut().for_each(|t| *t = t.square());
// compute base^100
let mut t2 = t1;
t2.iter_mut().for_each(|t| *t = t.square());
// compute base^100100
let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);
// compute base^100100100100
let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);
// compute base^100100100100100100100100
let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);
// compute base^100100100100100100100100100100
let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);
// compute base^1001001001001001001001001001000100100100100100100100100100100
let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
for (i, s) in state.iter_mut().enumerate() {
let a = (t7[i].square() * t6[i]).square().square();
let b = t1[i] * t2[i] * *s;
*s = a * b;
}
#[inline(always)]
fn exp_acc<B: StarkField, const N: usize, const M: usize>(
base: [B; N],
tail: [B; N],
) -> [B; N] {
let mut result = base;
for _ in 0..M {
result.iter_mut().for_each(|r| *r = r.square());
}
result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
result
}
}
#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}
// ROUND CONSTANTS
// ================================================================================================
/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of RPO round, and ARK2 contains constants for the second half of RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(5789762306288267392),
Felt::new(6522564764413701783),
Felt::new(17809893479458208203),
Felt::new(107145243989736508),
Felt::new(6388978042437517382),
Felt::new(15844067734406016715),
Felt::new(9975000513555218239),
Felt::new(3344984123768313364),
Felt::new(9959189626657347191),
Felt::new(12960773468763563665),
Felt::new(9602914297752488475),
Felt::new(16657542370200465908),
],
[
Felt::new(12987190162843096997),
Felt::new(653957632802705281),
Felt::new(4441654670647621225),
Felt::new(4038207883745915761),
Felt::new(5613464648874830118),
Felt::new(13222989726778338773),
Felt::new(3037761201230264149),
Felt::new(16683759727265180203),
Felt::new(8337364536491240715),
Felt::new(3227397518293416448),
Felt::new(8110510111539674682),
Felt::new(2872078294163232137),
],
[
Felt::new(18072785500942327487),
Felt::new(6200974112677013481),
Felt::new(17682092219085884187),
Felt::new(10599526828986756440),
Felt::new(975003873302957338),
Felt::new(8264241093196931281),
Felt::new(10065763900435475170),
Felt::new(2181131744534710197),
Felt::new(6317303992309418647),
Felt::new(1401440938888741532),
Felt::new(8884468225181997494),
Felt::new(13066900325715521532),
],
[
Felt::new(5674685213610121970),
Felt::new(5759084860419474071),
Felt::new(13943282657648897737),
Felt::new(1352748651966375394),
Felt::new(17110913224029905221),
Felt::new(1003883795902368422),
Felt::new(4141870621881018291),
Felt::new(8121410972417424656),
Felt::new(14300518605864919529),
Felt::new(13712227150607670181),
Felt::new(17021852944633065291),
Felt::new(6252096473787587650),
],
[
Felt::new(4887609836208846458),
Felt::new(3027115137917284492),
Felt::new(9595098600469470675),
Felt::new(10528569829048484079),
Felt::new(7864689113198939815),
Felt::new(17533723827845969040),
Felt::new(5781638039037710951),
Felt::new(17024078752430719006),
Felt::new(109659393484013511),
Felt::new(7158933660534805869),
Felt::new(2955076958026921730),
Felt::new(7433723648458773977),
],
[
Felt::new(16308865189192447297),
Felt::new(11977192855656444890),
Felt::new(12532242556065780287),
Felt::new(14594890931430968898),
Felt::new(7291784239689209784),
Felt::new(5514718540551361949),
Felt::new(10025733853830934803),
Felt::new(7293794580341021693),
Felt::new(6728552937464861756),
Felt::new(6332385040983343262),
Felt::new(13277683694236792804),
Felt::new(2600778905124452676),
],
[
Felt::new(7123075680859040534),
Felt::new(1034205548717903090),
Felt::new(7717824418247931797),
Felt::new(3019070937878604058),
Felt::new(11403792746066867460),
Felt::new(10280580802233112374),
Felt::new(337153209462421218),
Felt::new(13333398568519923717),
Felt::new(3596153696935337464),
Felt::new(8104208463525993784),
Felt::new(14345062289456085693),
Felt::new(17036731477169661256),
],
];
const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(6077062762357204287),
Felt::new(15277620170502011191),
Felt::new(5358738125714196705),
Felt::new(14233283787297595718),
Felt::new(13792579614346651365),
Felt::new(11614812331536767105),
Felt::new(14871063686742261166),
Felt::new(10148237148793043499),
Felt::new(4457428952329675767),
Felt::new(15590786458219172475),
Felt::new(10063319113072092615),
Felt::new(14200078843431360086),
],
[
Felt::new(6202948458916099932),
Felt::new(17690140365333231091),
Felt::new(3595001575307484651),
Felt::new(373995945117666487),
Felt::new(1235734395091296013),
Felt::new(14172757457833931602),
Felt::new(707573103686350224),
Felt::new(15453217512188187135),
Felt::new(219777875004506018),
Felt::new(17876696346199469008),
Felt::new(17731621626449383378),
Felt::new(2897136237748376248),
],
[
Felt::new(8023374565629191455),
Felt::new(15013690343205953430),
Felt::new(4485500052507912973),
Felt::new(12489737547229155153),
Felt::new(9500452585969030576),
Felt::new(2054001340201038870),
Felt::new(12420704059284934186),
Felt::new(355990932618543755),
Felt::new(9071225051243523860),
Felt::new(12766199826003448536),
Felt::new(9045979173463556963),
Felt::new(12934431667190679898),
],
[
Felt::new(18389244934624494276),
Felt::new(16731736864863925227),
Felt::new(4440209734760478192),
Felt::new(17208448209698888938),
Felt::new(8739495587021565984),
Felt::new(17000774922218161967),
Felt::new(13533282547195532087),
Felt::new(525402848358706231),
Felt::new(16987541523062161972),
Felt::new(5466806524462797102),
Felt::new(14512769585918244983),
Felt::new(10973956031244051118),
],
[
Felt::new(6982293561042362913),
Felt::new(14065426295947720331),
Felt::new(16451845770444974180),
Felt::new(7139138592091306727),
Felt::new(9012006439959783127),
Felt::new(14619614108529063361),
Felt::new(1394813199588124371),
Felt::new(4635111139507788575),
Felt::new(16217473952264203365),
Felt::new(10782018226466330683),
Felt::new(6844229992533662050),
Felt::new(7446486531695178711),
],
[
Felt::new(3736792340494631448),
Felt::new(577852220195055341),
Felt::new(6689998335515779805),
Felt::new(13886063479078013492),
Felt::new(14358505101923202168),
Felt::new(7744142531772274164),
Felt::new(16135070735728404443),
Felt::new(12290902521256031137),
Felt::new(12059913662657709804),
Felt::new(16456018495793751911),
Felt::new(4571485474751953524),
Felt::new(17200392109565783176),
],
[
Felt::new(17130398059294018733),
Felt::new(519782857322261988),
Felt::new(9625384390925085478),
Felt::new(1664893052631119222),
Felt::new(7629576092524553570),
Felt::new(3485239601103661425),
Felt::new(9755891797164033838),
Felt::new(15218148195153269027),
Felt::new(16460604813734957368),
Felt::new(9643968136937729763),
Felt::new(3611348709641382851),
Felt::new(18256379591337759196),
],
];

View file

@ -1,646 +0,0 @@
use alloc::string::String;
use core::{
cmp::Ordering,
fmt::Display,
hash::{Hash, Hasher},
ops::Deref,
slice,
};
use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
rand::Randomizable,
utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
},
};
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);
impl RpoDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
pub fn as_elements(&self) -> &[Felt] {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
impl Hash for RpoDigest {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.as_bytes());
}
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
result
}
}
impl Deref for RpoDigest {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Ord for RpoDigest {
fn cmp(&self, other: &Self) -> Ordering {
// compare the inner u64 of both elements.
//
// it will iterate the elements and will return the first computation different than
// `Equal`. Otherwise, the ordering is equal.
//
// the endianness is irrelevant here because since, this being a cryptographically secure
// hash computation, the digest shouldn't have any ordered property of its input.
//
// finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
// montgomery reduction for every limb. that is safe because every inner element of the
// digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
Ordering::Equal,
|ord, (a, b)| match ord {
Ordering::Equal => a.cmp(&b),
_ => ord,
},
)
}
}
impl PartialOrd for RpoDigest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Display for RpoDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
impl Randomizable for RpoDigest {
const VALUE_SIZE: usize = DIGEST_BYTES;
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
if let Some(bytes_array) = bytes_array {
Self::try_from(bytes_array).ok()
} else {
None
}
}
}
// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================
#[derive(Debug, Error)]
pub enum RpoDigestError {
#[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}
impl TryFrom<&RpoDigest> for [bool; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [bool; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}
Ok([
to_bool(value.0[0].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
])
}
}
impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
value.0
}
}
impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<&RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
value.to_hex()
}
}
// CONVERSIONS: TO RPO DIGEST
// ================================================================================================
impl From<&[bool; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[bool; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[bool; DIGEST_SIZE]> for RpoDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
Ok(Self([
value[0].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
]))
}
}
impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
Self(*value)
}
}
impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
}
impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: the input length is known, the conversion from slice to array must succeed so the
// `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
impl TryFrom<&[u8]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<&str> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpoDigest::try_from)
}
}
impl TryFrom<String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<&String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for RpoDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}
fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
}
impl Deserializable for RpoDigest {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
for inner in inner.iter_mut() {
let e = source.read_u64()?;
if e >= Felt::MODULUS {
return Err(DeserializationError::InvalidValue(String::from(
"Value not in the appropriate range",
)));
}
*inner = Felt::new(e);
}
Ok(Self(inner))
}
}
// ITERATORS
// ================================================================================================
impl IntoIterator for RpoDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::string::String;
use rand_utils::rand_value;
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
use crate::utils::SliceReader;
#[test]
fn digest_serialization() {
let e1 = Felt::new(rand_value());
let e2 = Felt::new(rand_value());
let e3 = Felt::new(rand_value());
let e4 = Felt::new(rand_value());
let d1 = RpoDigest([e1, e2, e3, e4]);
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());
let mut reader = SliceReader::new(&bytes);
let d2 = RpoDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[test]
fn digest_encoding() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpoDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
#[test]
fn test_conversions() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
// BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = v.into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: String = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
// BY REF
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: String = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
}
}

View file

@ -1,339 +0,0 @@
use core::ops::Range;
use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};
mod digest;
pub use digest::{RpoDigest, RpoDigestError};
#[cfg(test)]
mod tests;
// HASHER IMPLEMENTATION
// ================================================================================================
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577) while the padding rule follows the one
/// described [here](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus p = 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Rate size: r = 8 field elements.
/// * Capacity size: c = 4 field elements.
/// * Number of founds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above.
/// For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpo256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpo256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation of
/// the RPX hash function in Appendix C of its [specification](https://eprint.iacr.org/2023/1045),
/// one sees that doing so degrades only pre-image resistance, from its initial bound of c.log_2(p),
/// by as much as the log_2 of the size of the domain identifier space. Since pre-image resistance
/// becomes the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPO permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();
impl Hasher for Rpo256 {
/// Rpo256 collision resistance is 128-bits.
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpoDigest;
fn hash(bytes: &[u8]) -> Self::Digest {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
// this to achieve:
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
// iterate the chunks of bytes, creating a field element from each chunk and copying it
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
// and an additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;
// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
rate_pos + 1
}
});
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the rate as hash result.
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}
// apply the RPO permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
impl ElementHasher for Rpo256 {
type BaseField = Felt;
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
// convert the elements into a list of base field elements
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH];
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
// elements have been absorbed
let mut i = 0;
for &element in elements.iter() {
state[RATE_RANGE.start + i] = element;
i += 1;
if i % RATE_WIDTH == 0 {
Self::apply_permutation(&mut state);
i = 0;
}
}
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
if i > 0 {
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
}
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the state as hash result
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================
impl Rpo256 {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// The number of rounds is set to 7 to target 128-bit security level.
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
pub const STATE_WIDTH: usize = STATE_WIDTH;
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
/// MDS matrix used for computing the linear layer in a RPO round.
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
/// Round constants added to the hasher state in the first half of the RPO round.
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
/// Round constants added to the hasher state in the second half of the RPO round.
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
// TRAIT PASS-THROUGH FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> RpoDigest {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
<Self as ElementHasher>::hash_elements(elements)
}
// DOMAIN IDENTIFIER
// --------------------------------------------------------------------------------------------
/// Returns a hash of two digests and a domain identifier.
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpoDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// set the second capacity element to the domain value. The first capacity element is used
// for padding purposes.
state[CAPACITY_RANGE.start + 1] = domain;
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
// RESCUE PERMUTATION
// --------------------------------------------------------------------------------------------
/// Applies RPO permutation to the provided state.
#[inline(always)]
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
for i in 0..NUM_ROUNDS {
Self::apply_round(state, i);
}
}
/// RPO round function.
#[inline(always)]
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
// apply first half of RPO round
apply_mds(state);
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
add_constants(state, &ARK1[round]);
apply_sbox(state);
}
// apply second half of RPO round
apply_mds(state);
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
add_constants(state, &ARK2[round]);
apply_inv_sbox(state);
}
}
}

View file

@ -1,387 +0,0 @@
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*;
use rand_utils::rand_value;
use super::{
super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, STATE_WIDTH, ZERO,
};
use crate::{
hash::rescue::{BINARY_CHUNK_SIZE, CAPACITY_RANGE, RATE_WIDTH},
Word, ONE,
};
#[test]
fn test_sbox() {
let state = [Felt::new(rand_value()); STATE_WIDTH];
let mut expected = state;
expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));
let mut actual = state;
apply_sbox(&mut actual);
assert_eq!(expected, actual);
}
#[test]
fn test_inv_sbox() {
let state = [Felt::new(rand_value()); STATE_WIDTH];
let mut expected = state;
expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));
let mut actual = state;
apply_inv_sbox(&mut actual);
assert_eq!(expected, actual);
}
#[test]
fn hash_elements_vs_merge() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpoDigest; 2] = [
RpoDigest::new(elements[..4].try_into().unwrap()),
RpoDigest::new(elements[4..].try_into().unwrap()),
];
let m_result = Rpo256::merge(&digests);
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn merge_vs_merge_in_domain() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpoDigest; 2] = [
RpoDigest::new(elements[..4].try_into().unwrap()),
RpoDigest::new(elements[4..].try_into().unwrap()),
];
let merge_result = Rpo256::merge(&digests);
// ------------- merge with domain = 0 -------------
// set domain to ZERO. This should not change the result.
let domain = ZERO;
let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);
// ------------- merge with domain = 1 -------------
// set domain to ONE. This should change the result.
let domain = ONE;
let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
assert_ne!(merge_result, merge_in_domain_result);
}
#[test]
fn hash_elements_vs_merge_with_int() {
let tmp = [Felt::new(rand_value()); 4];
let seed = RpoDigest::new(tmp);
// ----- value fits into a field element ------------------------------------------------------
let val: Felt = Felt::new(rand_value());
let m_result = Rpo256::merge_with_int(seed, val.as_int());
let mut elements = seed.as_elements().to_vec();
elements.push(val);
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(m_result, h_result);
// ----- value does not fit into a field element ----------------------------------------------
let val = Felt::MODULUS + 2;
let m_result = Rpo256::merge_with_int(seed, val);
let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE);
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_padding() {
// adding a zero bytes at the end of a byte string should result in a different hash
let r1 = Rpo256::hash(&[1_u8, 2, 3]);
let r2 = Rpo256::hash(&[1_u8, 2, 3, 0]);
assert_ne!(r1, r2);
// same as above but with bigger inputs
let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6]);
let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
assert_ne!(r1, r2);
// same as above but with input splitting over two elements
let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
assert_ne!(r1, r2);
// same as above but with multiple zeros
let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
assert_ne!(r1, r2);
}
#[test]
fn hash_padding_no_extra_permutation_call() {
use crate::hash::rescue::DIGEST_RANGE;
// Implementation
let num_bytes = BINARY_CHUNK_SIZE * RATE_WIDTH;
let mut buffer = vec![0_u8; num_bytes];
*buffer.last_mut().unwrap() = 97;
let r1 = Rpo256::hash(&buffer);
// Expected
let final_chunk = [0_u8, 0, 0, 0, 0, 0, 97, 1];
let mut state = [ZERO; STATE_WIDTH];
// padding when hashing bytes
state[CAPACITY_RANGE.start] = Felt::from(RATE_WIDTH as u8);
*state.last_mut().unwrap() = Felt::new(u64::from_le_bytes(final_chunk));
Rpo256::apply_permutation(&mut state);
assert_eq!(&r1[0..4], &state[DIGEST_RANGE]);
}
#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
let e2 = [e1[0], e1[1], ZERO];
let r1 = Rpo256::hash_elements(&e1);
let r2 = Rpo256::hash_elements(&e2);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];
let digests: [RpoDigest; 2] = [
RpoDigest::new(elements[..4].try_into().unwrap()),
RpoDigest::new(elements[4..8].try_into().unwrap()),
];
let m_result = Rpo256::merge(&digests);
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_test_vectors() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
Felt::new(8),
Felt::new(9),
Felt::new(10),
Felt::new(11),
Felt::new(12),
Felt::new(13),
Felt::new(14),
Felt::new(15),
Felt::new(16),
Felt::new(17),
Felt::new(18),
];
for i in 0..elements.len() {
let expected = RpoDigest::new(EXPECTED[i]);
let result = Rpo256::hash_elements(&elements[..(i + 1)]);
assert_eq!(result, expected);
}
}
#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
// this test targets to assert that no panic will happen with the edge case of having an inputs
// with length that is not divisible by the used binary chunk size. 113 is a non-negligible
// input length that is prime; hence guaranteed to not be divisible by any choice of chunk
// size.
//
// this is a preliminary test to the fuzzy-stress of proptest.
Rpo256::hash(&[0; 113]);
}
#[test]
fn sponge_collision_for_wrapped_field_element() {
let a = Rpo256::hash(&[0; 8]);
let b = Rpo256::hash(&Felt::MODULUS.to_le_bytes());
assert_ne!(a, b);
}
#[test]
fn sponge_zeroes_collision() {
let mut zeroes = Vec::with_capacity(255);
let mut set = BTreeSet::new();
(0..255).for_each(|_| {
let hash = Rpo256::hash(&zeroes);
zeroes.push(0);
// panic if a collision was found
assert!(set.insert(hash));
});
}
proptest! {
#[test]
fn rpo256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
Rpo256::hash(bytes);
}
}
const EXPECTED: [Word; 19] = [
[
Felt::new(18126731724905382595),
Felt::new(7388557040857728717),
Felt::new(14290750514634285295),
Felt::new(7852282086160480146),
],
[
Felt::new(10139303045932500183),
Felt::new(2293916558361785533),
Felt::new(15496361415980502047),
Felt::new(17904948502382283940),
],
[
Felt::new(17457546260239634015),
Felt::new(803990662839494686),
Felt::new(10386005777401424878),
Felt::new(18168807883298448638),
],
[
Felt::new(13072499238647455740),
Felt::new(10174350003422057273),
Felt::new(9201651627651151113),
Felt::new(6872461887313298746),
],
[
Felt::new(2903803350580990546),
Felt::new(1838870750730563299),
Felt::new(4258619137315479708),
Felt::new(17334260395129062936),
],
[
Felt::new(8571221005243425262),
Felt::new(3016595589318175865),
Felt::new(13933674291329928438),
Felt::new(678640375034313072),
],
[
Felt::new(16314113978986502310),
Felt::new(14587622368743051587),
Felt::new(2808708361436818462),
Felt::new(10660517522478329440),
],
[
Felt::new(2242391899857912644),
Felt::new(12689382052053305418),
Felt::new(235236990017815546),
Felt::new(5046143039268215739),
],
[
Felt::new(5218076004221736204),
Felt::new(17169400568680971304),
Felt::new(8840075572473868990),
Felt::new(12382372614369863623),
],
[
Felt::new(9783834557155203486),
Felt::new(12317263104955018849),
Felt::new(3933748931816109604),
Felt::new(1843043029836917214),
],
[
Felt::new(14498234468286984551),
Felt::new(16837257669834682387),
Felt::new(6664141123711355107),
Felt::new(4590460158294697186),
],
[
Felt::new(4661800562479916067),
Felt::new(11794407552792839953),
Felt::new(9037742258721863712),
Felt::new(6287820818064278819),
],
[
Felt::new(7752693085194633729),
Felt::new(7379857372245835536),
Felt::new(9270229380648024178),
Felt::new(10638301488452560378),
],
[
Felt::new(11542686762698783357),
Felt::new(15570714990728449027),
Felt::new(7518801014067819501),
Felt::new(12706437751337583515),
],
[
Felt::new(9553923701032839042),
Felt::new(7281190920209838818),
Felt::new(2488477917448393955),
Felt::new(5088955350303368837),
],
[
Felt::new(4935426252518736883),
Felt::new(12584230452580950419),
Felt::new(8762518969632303998),
Felt::new(18159875708229758073),
],
[
Felt::new(12795429638314178838),
Felt::new(14360248269767567855),
Felt::new(3819563852436765058),
Felt::new(10859123583999067291),
],
[
Felt::new(2695742617679420093),
Felt::new(9151515850666059759),
Felt::new(15855828029180595485),
Felt::new(17190029785471463210),
],
[
Felt::new(13205273108219124830),
Felt::new(2524898486192849221),
Felt::new(14618764355375283547),
Felt::new(10615614265042186874),
],
];

View file

@ -1,634 +0,0 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
rand::Randomizable,
utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
},
};
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);
impl RpxDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
pub fn as_elements(&self) -> &[Felt] {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
impl Digest for RpxDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
result
}
}
impl Deref for RpxDigest {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Ord for RpxDigest {
fn cmp(&self, other: &Self) -> Ordering {
// compare the inner u64 of both elements.
//
// it will iterate the elements and will return the first computation different than
// `Equal`. Otherwise, the ordering is equal.
//
// the endianness is irrelevant here because since, this being a cryptographically secure
// hash computation, the digest shouldn't have any ordered property of its input.
//
// finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
// montgomery reduction for every limb. that is safe because every inner element of the
// digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
Ordering::Equal,
|ord, (a, b)| match ord {
Ordering::Equal => a.cmp(&b),
_ => ord,
},
)
}
}
impl PartialOrd for RpxDigest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Display for RpxDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
impl Randomizable for RpxDigest {
const VALUE_SIZE: usize = DIGEST_BYTES;
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
if let Some(bytes_array) = bytes_array {
Self::try_from(bytes_array).ok()
} else {
None
}
}
}
// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================
#[derive(Debug, Error)]
pub enum RpxDigestError {
#[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}
impl TryFrom<&RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}
Ok([
to_bool(value.0[0].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
])
}
}
impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
(*value).into()
}
}
impl From<RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
value.0
}
}
impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
value.0
}
}
impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpxDigest) -> Self {
value.as_bytes()
}
}
impl From<RpxDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpxDigest) -> Self {
value.as_bytes()
}
}
impl From<&RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpxDigest) -> Self {
(*value).into()
}
}
impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
}
}
// CONVERSIONS: TO RPX DIGEST
// ================================================================================================
impl From<&[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[bool; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
]))
}
}
impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
Self(*value)
}
}
impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
}
impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: the input length is known, the conversion from slice to array must succeed so the
// `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
impl TryFrom<&[u8]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<&str> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpxDigest::try_from)
}
}
impl TryFrom<&String> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<String> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for RpxDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}
fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
}
impl Deserializable for RpxDigest {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
for inner in inner.iter_mut() {
let e = source.read_u64()?;
if e >= Felt::MODULUS {
return Err(DeserializationError::InvalidValue(String::from(
"Value not in the appropriate range",
)));
}
*inner = Felt::new(e);
}
Ok(Self(inner))
}
}
// ITERATORS
// ================================================================================================
impl IntoIterator for RpxDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::string::String;
use rand_utils::rand_value;
use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
use crate::utils::SliceReader;
#[test]
fn digest_serialization() {
let e1 = Felt::new(rand_value());
let e2 = Felt::new(rand_value());
let e3 = Felt::new(rand_value());
let e4 = Felt::new(rand_value());
let d1 = RpxDigest([e1, e2, e3, e4]);
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());
let mut reader = SliceReader::new(&bytes);
let d2 = RpxDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[test]
fn digest_encoding() {
let digest = RpxDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpxDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
#[test]
fn test_conversions() {
let digest = RpxDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
// BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = v.into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: String = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
// BY REF
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: String = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
}
}

View file

@ -1,385 +0,0 @@
use core::ops::Range;
use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
ZERO,
};
mod digest;
pub use digest::{RpxDigest, RpxDigestError};
#[cfg(test)]
mod tests;
pub type CubicExtElement = CubeExtension<Felt>;
// HASHER IMPLEMENTATION
// ================================================================================================
/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
///
/// The hash function is based on the XHash12 construction in [specifications](https://eprint.iacr.org/2023/1045)
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
/// `apply_inv_sbox`.
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension
/// field).
/// - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpx256::hash_elements) function.
///
/// However, [hash()](Rpx256::hash) function is not consistent with functions mentioned above.
/// For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these
/// elements directly using [hash_elements()](Rpx256::hash_elements) function. The reason for
/// this difference is that [hash()](Rpx256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpx256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpx256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation
/// in Appendix C of the [specifications](https://eprint.iacr.org/2023/1045), one sees that doing
/// so degrades only pre-image resistance, from its initial bound of c.log_2(p), by as much as
/// the log_2 of the size of the domain identifier space. Since pre-image resistance becomes
/// the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPX permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();
impl Hasher for Rpx256 {
/// Rpx256 collision resistance is 128-bits.
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpxDigest;
fn hash(bytes: &[u8]) -> Self::Digest {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
// this to achieve:
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
// iterate the chunks of bytes, creating a field element from each chunk and copying it
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
// and an additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;
// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
rate_pos + 1
}
});
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the rate as hash result.
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPX permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}
// apply the RPX permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
impl ElementHasher for Rpx256 {
type BaseField = Felt;
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
// convert the elements into a list of base field elements
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH];
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
// elements have been absorbed
let mut i = 0;
for &element in elements.iter() {
state[RATE_RANGE.start + i] = element;
i += 1;
if i % RATE_WIDTH == 0 {
Self::apply_permutation(&mut state);
i = 0;
}
}
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
if i > 0 {
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
}
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the state as hash result
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================
impl Rpx256 {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
pub const STATE_WIDTH: usize = STATE_WIDTH;
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
/// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
/// Round constants added to the hasher state in the first half of the round.
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
/// Round constants added to the hasher state in the second half of the round.
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
// TRAIT PASS-THROUGH FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> RpxDigest {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
<Self as ElementHasher>::hash_elements(elements)
}
// DOMAIN IDENTIFIER
// --------------------------------------------------------------------------------------------
/// Returns a hash of two digests and a domain identifier.
pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpxDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// set the second capacity element to the domain value. The first capacity element is used
// for padding purposes.
state[CAPACITY_RANGE.start + 1] = domain;
// apply the RPX permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
// RPX PERMUTATION
// --------------------------------------------------------------------------------------------
/// Applies RPX permutation to the provided state.
#[inline(always)]
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
Self::apply_fb_round(state, 0);
Self::apply_ext_round(state, 1);
Self::apply_fb_round(state, 2);
Self::apply_ext_round(state, 3);
Self::apply_fb_round(state, 4);
Self::apply_ext_round(state, 5);
Self::apply_final_round(state, 6);
}
// RPX PERMUTATION ROUND FUNCTIONS
// --------------------------------------------------------------------------------------------
/// (FB) round function.
#[inline(always)]
pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
apply_mds(state);
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
add_constants(state, &ARK1[round]);
apply_sbox(state);
}
apply_mds(state);
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
add_constants(state, &ARK2[round]);
apply_inv_sbox(state);
}
}
/// (E) round function.
#[inline(always)]
pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
// add constants
add_constants(state, &ARK1[round]);
// decompose the state into 4 elements in the cubic extension field and apply the power 7
// map to each of the elements
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));
// decompose the state back into 12 base field elements
let arr_ext = [ext0, ext1, ext2, ext3];
*state = CubicExtElement::slice_as_base_elements(&arr_ext)
.try_into()
.expect("shouldn't fail");
}
/// (M) round function.
#[inline(always)]
pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
apply_mds(state);
add_constants(state, &ARK1[round]);
}
/// Computes an exponentiation to the power 7 in cubic extension field.
#[inline(always)]
pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
let x2 = x.square();
let x4 = x2.square();
let x3 = x2 * x;
x3 * x4
}
}

View file

@ -1,186 +0,0 @@
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*;
use rand_utils::rand_value;
use super::{Felt, Hasher, Rpx256, StarkField, ZERO};
use crate::{hash::rescue::RpxDigest, ONE};
#[test]
fn hash_elements_vs_merge() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn merge_vs_merge_in_domain() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let merge_result = Rpx256::merge(&digests);
// ----- merge with domain = 0 ----------------------------------------------------------------
// set domain to ZERO. This should not change the result.
let domain = ZERO;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);
// ----- merge with domain = 1 ----------------------------------------------------------------
// set domain to ONE. This should change the result.
let domain = ONE;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_ne!(merge_result, merge_in_domain_result);
}
#[test]
fn hash_elements_vs_merge_with_int() {
let tmp = [Felt::new(rand_value()); 4];
let seed = RpxDigest::new(tmp);
// ----- value fits into a field element ------------------------------------------------------
let val: Felt = Felt::new(rand_value());
let m_result = Rpx256::merge_with_int(seed, val.as_int());
let mut elements = seed.as_elements().to_vec();
elements.push(val);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
// ----- value does not fit into a field element ----------------------------------------------
let val = Felt::MODULUS + 2;
let m_result = Rpx256::merge_with_int(seed, val);
let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_padding() {
// adding a zero bytes at the end of a byte string should result in a different hash
let r1 = Rpx256::hash(&[1_u8, 2, 3]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 0]);
assert_ne!(r1, r2);
// same as above but with bigger inputs
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
assert_ne!(r1, r2);
// same as above but with input splitting over two elements
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
assert_ne!(r1, r2);
// same as above but with multiple zeros
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
let e2 = [e1[0], e1[1], ZERO];
let r1 = Rpx256::hash_elements(&e1);
let r2 = Rpx256::hash_elements(&e2);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..8].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
// this test targets to assert that no panic will happen with the edge case of having an inputs
// with length that is not divisible by the used binary chunk size. 113 is a non-negligible
// input length that is prime; hence guaranteed to not be divisible by any choice of chunk
// size.
//
// this is a preliminary test to the fuzzy-stress of proptest.
Rpx256::hash(&[0; 113]);
}
#[test]
fn sponge_collision_for_wrapped_field_element() {
let a = Rpx256::hash(&[0; 8]);
let b = Rpx256::hash(&Felt::MODULUS.to_le_bytes());
assert_ne!(a, b);
}
#[test]
fn sponge_zeroes_collision() {
let mut zeroes = Vec::with_capacity(255);
let mut set = BTreeSet::new();
(0..255).for_each(|_| {
let hash = Rpx256::hash(&zeroes);
zeroes.push(0);
// panic if a collision was found
assert!(set.insert(hash));
});
}
proptest! {
#[test]
fn rpo256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
Rpx256::hash(bytes);
}
}

View file

@ -1,10 +0,0 @@
use rand_utils::rand_value;
use super::{Felt, FieldElement, ALPHA, INV_ALPHA};
#[test]
fn test_alphas() {
let e: Felt = Felt::new(rand_value());
let e_exp = e.exp(ALPHA);
assert_eq!(e, e_exp.exp(INV_ALPHA));
}

View file

@ -1,71 +0,0 @@
#![no_std]
#[macro_use]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
pub mod dsa;
pub mod hash;
pub mod merkle;
pub mod rand;
pub mod utils;
// RE-EXPORTS
// ================================================================================================
pub use winter_math::{
fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
FieldElement, StarkField,
};
// TYPE ALIASES
// ================================================================================================
/// A group of four field elements in the Miden base field.
pub type Word = [Felt; WORD_SIZE];
// CONSTANTS
// ================================================================================================
/// Number of field elements in a word.
pub const WORD_SIZE: usize = 4;
/// Field element representing ZERO in the Miden base filed.
pub const ZERO: Felt = Felt::ZERO;
/// Field element representing ONE in the Miden base filed.
pub const ONE: Felt = Felt::ONE;
/// Array of field elements representing word of ZEROs in the Miden base field.
pub const EMPTY_WORD: [Felt; 4] = [ZERO; WORD_SIZE];
// TESTS
// ================================================================================================
#[test]
#[should_panic]
fn debug_assert_is_checked() {
// enforce the release checks to always have `RUSTFLAGS="-C debug-assertions".
//
// some upstream tests are performed with `debug_assert`, and we want to assert its correctness
// downstream.
//
// for reference, check
// https://github.com/0xPolygonMiden/miden-vm/issues/433
debug_assert!(false);
}
#[test]
#[should_panic]
#[allow(arithmetic_overflow)]
fn overflow_panics_for_test() {
// overflows might be disabled if tests are performed in release mode. these are critical,
// mandatory checks as overflows might be attack vectors.
//
// to enable overflow checks in release mode, ensure `RUSTFLAGS="-C overflow-checks"`
let a = 1_u64;
let b = 64;
assert_ne!(a << b, 0);
}

View file

@ -1,219 +0,0 @@
use std::time::Instant;
use clap::Parser;
use miden_crypto::{
hash::rpo::{Rpo256, RpoDigest},
merkle::{MerkleError, Smt},
Felt, Word, EMPTY_WORD, ONE,
};
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use rand_utils::rand_value;
#[derive(Parser, Debug)]
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size", default_value = "1000000")]
size: usize,
/// Number of insertions
#[clap(short = 'i', long = "insertions", default_value = "1000")]
insertions: usize,
/// Number of updates
#[clap(short = 'u', long = "updates", default_value = "1000")]
updates: usize,
}
fn main() {
benchmark_smt();
}
/// Run a benchmark for [`Smt`].
pub fn benchmark_smt() {
let args = BenchmarkCmd::parse();
let tree_size = args.size;
let insertions = args.insertions;
let updates = args.updates;
assert!(updates <= tree_size, "Cannot update more than `size`");
// prepare the `leaves` vector for tree creation
let mut entries = Vec::new();
for i in 0..tree_size {
let key = rand_value::<RpoDigest>();
let value = [ONE, ONE, ONE, Felt::new(i as u64)];
entries.push((key, value));
}
let mut tree = construction(entries.clone(), tree_size).unwrap();
insertion(&mut tree.clone(), insertions).unwrap();
batched_insertion(&mut tree.clone(), insertions).unwrap();
batched_update(&mut tree.clone(), entries, updates).unwrap();
proof_generation(&mut tree).unwrap();
}
/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
println!("Running a construction benchmark:");
let now = Instant::now();
let tree = Smt::with_entries(entries)?;
let elapsed = now.elapsed().as_secs_f32();
println!("Constructed an SMT with {size} key-value pairs in {elapsed:.1} seconds");
println!("Number of leaf nodes: {}\n", tree.leaves().count());
Ok(tree)
}
/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running an insertion benchmark:");
let size = tree.num_leaves();
let mut insertion_times = Vec::new();
for i in 0..insertions {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
let now = Instant::now();
tree.insert(test_key, test_value);
let elapsed = now.elapsed();
insertion_times.push(elapsed.as_micros());
}
println!(
"The average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (insertions as f64),
);
Ok(())
}
pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");
let size = tree.num_leaves();
let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
(key, value)
})
.collect();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"The average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / insertions as f64, // time in μs
);
let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"The average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / insertions as f64, // time in μs
);
println!(
"The average batch insertion time measured by a {insertions}-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
Ok(())
}
pub fn batched_update(
tree: &mut Smt,
entries: Vec<(RpoDigest, Word)>,
updates: usize,
) -> Result<(), MerkleError> {
const REMOVAL_PROBABILITY: f64 = 0.2;
println!("Running a batched update benchmark:");
let size = tree.num_leaves();
let mut rng = thread_rng();
let new_pairs =
entries
.into_iter()
.choose_multiple(&mut rng, updates)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};
(key, value)
});
assert_eq!(new_pairs.len(), updates);
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"The average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / updates as f64, // time in μs
);
println!(
"The average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / updates as f64, // time in μs
);
println!(
"The average batch update time measured by a {updates}-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
Ok(())
}
/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_PROOFS: usize = 100;
println!("Running a proof generation benchmark:");
let mut insertion_times = Vec::new();
let size = tree.num_leaves();
for i in 0..NUM_PROOFS {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
tree.insert(test_key, test_value);
let now = Instant::now();
let _proof = tree.open(&test_key);
insertion_times.push(now.elapsed().as_micros());
}
println!(
"The average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
);
Ok(())
}

File diff suppressed because it is too large Load diff

View file

@ -1,36 +0,0 @@
use thiserror::Error;
use super::{NodeIndex, RpoDigest};
#[derive(Debug, Error)]
pub enum MerkleError {
#[error("expected merkle root {expected_root} found {actual_root}")]
ConflictingRoots {
expected_root: RpoDigest,
actual_root: RpoDigest,
},
#[error("provided merkle tree depth {0} is too small")]
DepthTooSmall(u8),
#[error("provided merkle tree depth {0} is too big")]
DepthTooBig(u64),
#[error("multiple values provided for merkle tree index {0}")]
DuplicateValuesForIndex(u64),
#[error("node index value {value} is not valid for depth {depth}")]
InvalidNodeIndex { depth: u8, value: u64 },
#[error("provided node index depth {provided} does not match expected depth {expected}")]
InvalidNodeIndexDepth { expected: u8, provided: u8 },
#[error("merkle subtree depth {subtree_depth} exceeds merkle tree depth {tree_depth}")]
SubtreeDepthExceedsDepth { subtree_depth: u8, tree_depth: u8 },
#[error("number of entries in the merkle tree exceeds the maximum of {0}")]
TooManyEntries(usize),
#[error("node index `{0}` not found in the tree")]
NodeIndexNotFoundInTree(NodeIndex),
#[error("node {0:?} with index `{1}` not found in the store")]
NodeIndexNotFoundInStore(RpoDigest, NodeIndex),
#[error("number of provided merkle tree leaves {0} is not a power of two")]
NumLeavesNotPowerOfTwo(usize),
#[error("root {0:?} is not in the store")]
RootNotInStore(RpoDigest),
#[error("partial smt does not track the merkle path for key {0} so updating it would produce a different root compared to the same update in the full tree")]
UntrackedKey(RpoDigest),
}

View file

@ -1,244 +0,0 @@
use core::fmt::Display;
use super::{Felt, MerkleError, RpoDigest};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
// NODE INDEX
// ================================================================================================
/// Address to an arbitrary node in a binary tree using level order form.
///
/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
/// are numbered from $0..(2^d)-1$. Example:
///
/// ```ignore
/// depth
/// 0 0
/// 1 0 1
/// 2 0 1 2 3
/// 3 0 1 2 3 4 5 6 7
/// ```
///
/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
/// $(1, 1)$.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
depth: u8,
value: u64,
}
impl NodeIndex {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Creates a new node index.
///
/// # Errors
/// Returns an error if the `value` is greater than or equal to 2^{depth}.
pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
if (64 - value.leading_zeros()) > depth as u32 {
Err(MerkleError::InvalidNodeIndex { depth, value })
} else {
Ok(Self { depth, value })
}
}
/// Creates a new node index without checking its validity.
pub const fn new_unchecked(depth: u8, value: u64) -> Self {
debug_assert!((64 - value.leading_zeros()) <= depth as u32);
Self { depth, value }
}
/// Creates a new node index for testing purposes.
///
/// # Panics
/// Panics if the `value` is greater than or equal to 2^{depth}.
#[cfg(test)]
pub fn make(depth: u8, value: u64) -> Self {
Self::new(depth, value).unwrap()
}
/// Creates a node index from a pair of field elements representing the depth and value.
///
/// # Errors
/// Returns an error if:
/// - `depth` doesn't fit in a `u8`.
/// - `value` is greater than or equal to 2^{depth}.
pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
let depth = depth.as_int();
let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
let value = value.as_int();
Self::new(depth, value)
}
/// Creates a new node index pointing to the root of the tree.
pub const fn root() -> Self {
Self { depth: 0, value: 0 }
}
/// Computes sibling index of the current node.
pub const fn sibling(mut self) -> Self {
self.value ^= 1;
self
}
/// Returns left child index of the current node.
pub const fn left_child(mut self) -> Self {
self.depth += 1;
self.value <<= 1;
self
}
/// Returns right child index of the current node.
pub const fn right_child(mut self) -> Self {
self.depth += 1;
self.value = (self.value << 1) + 1;
self
}
/// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
/// a new value instead of mutating `self`.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
/// Builds a node to be used as input of a hash function when computing a Merkle path.
///
/// Will evaluate the parity of the current instance to define the result.
pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
if self.is_value_odd() {
[sibling, slf]
} else {
[slf, sibling]
}
}
/// Returns the scalar representation of the depth/value pair.
///
/// It is computed as `2^depth + value`.
pub const fn to_scalar_index(&self) -> u64 {
(1 << self.depth as u64) + self.value
}
/// Returns the depth of the current instance.
pub const fn depth(&self) -> u8 {
self.depth
}
/// Returns the value of this index.
pub const fn value(&self) -> u64 {
self.value
}
/// Returns `true` if the current instance points to a right sibling node.
pub const fn is_value_odd(&self) -> bool {
(self.value & 1) == 1
}
/// Returns `true` if the depth is `0`.
pub const fn is_root(&self) -> bool {
self.depth == 0
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Traverses one level towards the root, decrementing the depth by `1`.
pub fn move_up(&mut self) {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
}
/// Traverses towards the root until the specified depth is reached.
///
/// Assumes that the specified depth is smaller than the current depth.
pub fn move_up_to(&mut self, depth: u8) {
debug_assert!(depth < self.depth);
let delta = self.depth.saturating_sub(depth);
self.depth = self.depth.saturating_sub(delta);
self.value >>= delta as u32;
}
}
impl Display for NodeIndex {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "depth={}, value={}", self.depth, self.value)
}
}
impl Serializable for NodeIndex {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(self.depth);
target.write_u64(self.value);
}
}
impl Deserializable for NodeIndex {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let depth = source.read_u8()?;
let value = source.read_u64()?;
NodeIndex::new(depth, value)
.map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
}
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use proptest::prelude::*;
use super::*;
#[test]
fn test_node_index_value_too_high() {
assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
let err = NodeIndex::new(0, 1).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });
assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
let err = NodeIndex::new(1, 2).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });
assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
let err = NodeIndex::new(2, 4).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });
assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
let err = NodeIndex::new(3, 8).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
}
#[test]
fn test_node_index_can_represent_depth_64() {
assert!(NodeIndex::new(64, u64::MAX).is_ok());
}
prop_compose! {
fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
// unwrap never panics because the range of depth is 0..u64::BITS
let mut depth = value.ilog2() as u8;
if value > (1 << depth) { // round up
depth += 1;
}
NodeIndex::new(depth, value).unwrap()
}
}
proptest! {
#[test]
fn arbitrary_index_wont_panic_on_move_up(
mut index in node_index(),
count in prop::num::u8::ANY,
) {
for _ in 0..count {
index.move_up();
}
}
}
}

View file

@ -1,450 +0,0 @@
use alloc::{string::String, vec::Vec};
use core::{fmt, ops::Deref, slice};
use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
use crate::utils::{uninit_vector, word_to_hex};
// MERKLE TREE
// ================================================================================================
/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTree {
nodes: Vec<RpoDigest>,
}
impl MerkleTree {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Returns a Merkle tree instantiated from the provided leaves.
///
/// # Errors
/// Returns an error if the number of leaves is smaller than two or is not a power of two.
pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
where
T: AsRef<[Word]>,
{
let leaves = leaves.as_ref();
let n = leaves.len();
if n <= 1 {
return Err(MerkleError::DepthTooSmall(n as u8));
} else if !n.is_power_of_two() {
return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
}
// create un-initialized vector to hold all tree nodes
let mut nodes = unsafe { uninit_vector(2 * n) };
nodes[0] = RpoDigest::default();
// copy leaves into the second part of the nodes vector
nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
*node = RpoDigest::from(*leaf);
});
// re-interpret nodes as an array of two nodes fused together
// Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
// `self`).
let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
let pairs = unsafe { slice::from_raw_parts(ptr, n) };
// calculate all internal tree nodes
for i in (1..n).rev() {
nodes[i] = Rpo256::merge(&pairs[i]);
}
Ok(Self { nodes })
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the root of this Merkle tree.
pub fn root(&self) -> RpoDigest {
self.nodes[1]
}
/// Returns the depth of this Merkle tree.
///
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
pub fn depth(&self) -> u8 {
(self.nodes.len() / 2).ilog2() as u8
}
/// Returns a node at the specified depth and index value.
///
/// # Errors
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified index is not valid for the specified depth.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}
let pos = index.to_scalar_index() as usize;
Ok(self.nodes[pos])
}
/// Returns a Merkle path to the node at the specified depth and index value. The node itself
/// is not included in the path.
///
/// # Errors
/// Returns an error if:
/// * The specified depth is greater than the depth of the tree.
/// * The specified value is not valid for the specified depth.
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}
// TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so
// we always use inlined `for` instead of `while`? the reason to use `for` is because its
// easier for the compiler to vectorize.
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let sibling = index.sibling().to_scalar_index() as usize;
path.push(self.nodes[sibling]);
index.move_up();
}
debug_assert!(index.is_root(), "the path walk must go all the way to the root");
Ok(path.into())
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the leaves of this [MerkleTree].
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
let leaves_start = self.nodes.len() / 2;
self.nodes
.iter()
.skip(leaves_start)
.enumerate()
.map(|(i, v)| (i as u64, v.deref()))
}
/// Returns n iterator over every inner node of this [MerkleTree].
///
/// The iterator order is unspecified.
pub fn inner_nodes(&self) -> InnerNodeIterator {
InnerNodeIterator {
nodes: &self.nodes,
index: 1, // index 0 is just padding, start at 1
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Replaces the leaf at the specified index with the provided value.
///
/// # Errors
/// Returns an error if the specified index value is not a valid leaf value for this tree.
pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
let mut index = NodeIndex::new(self.depth(), index_value)?;
// we don't need to copy the pairs into a new address as we are logically guaranteed to not
// overlap write instructions. however, it's important to bind the lifetime of pairs to
// `self.nodes` so the compiler will never move one without moving the other.
debug_assert_eq!(self.nodes.len() & 1, 0);
let n = self.nodes.len() / 2;
// Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
// digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
// `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
// logically guaranteed to not overlap write positions as the write index is always half
// the index from which we read the digest input.
let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
// update the current node
let pos = index.to_scalar_index() as usize;
self.nodes[pos] = value.into();
// traverse to the root, updating each node with the merged values of its parents
for _ in 0..index.depth() {
index.move_up();
let pos = index.to_scalar_index() as usize;
let value = Rpo256::merge(&pairs[pos]);
self.nodes[pos] = value;
}
Ok(())
}
}
// CONVERSIONS
// ================================================================================================
impl TryFrom<&[Word]> for MerkleTree {
type Error = MerkleError;
fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
MerkleTree::new(value)
}
}
impl TryFrom<&[RpoDigest]> for MerkleTree {
type Error = MerkleError;
fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
MerkleTree::new(value)
}
}
// ITERATORS
// ================================================================================================
/// An iterator over every inner node of the [MerkleTree].
///
/// Use this to extract the data of the tree, there is no guarantee on the order of the elements.
pub struct InnerNodeIterator<'a> {
nodes: &'a Vec<RpoDigest>,
index: usize,
}
impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
if self.index < self.nodes.len() / 2 {
let value = self.index;
let left = self.index * 2;
let right = left + 1;
self.index += 1;
Some(InnerNodeInfo {
value: self.nodes[value],
left: self.nodes[left],
right: self.nodes[right],
})
} else {
None
}
}
}
// UTILITY FUNCTIONS
// ================================================================================================
/// Utility to visualize a [MerkleTree] in text.
pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
let indent = " ";
let mut s = String::new();
s.push_str(&word_to_hex(&tree.root())?);
s.push('\n');
for d in 1..=tree.depth() {
let entries = 2u64.pow(d.into());
for i in 0..entries {
let index = NodeIndex::new(d, i).expect("The index must always be valid");
let node = tree.get_node(index).expect("The node must always be found");
for _ in 0..d {
s.push_str(indent);
}
s.push_str(&word_to_hex(&node)?);
s.push('\n');
}
}
Ok(s)
}
/// Utility to visualize a [MerklePath] in text.
pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
let mut s = String::new();
s.push('[');
for el in path.iter() {
s.push_str(&word_to_hex(el)?);
s.push_str(", ");
}
// remove the last ", "
if !path.is_empty() {
s.pop();
s.pop();
}
s.push(']');
Ok(s)
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use core::mem::size_of;
use proptest::prelude::*;
use super::*;
use crate::{
merkle::{digests_to_words, int_to_leaf, int_to_node},
Felt, WORD_SIZE,
};
const LEAVES4: [RpoDigest; WORD_SIZE] =
[int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const LEAVES8: [RpoDigest; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
];
#[test]
fn build_merkle_tree() {
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
assert_eq!(8, tree.nodes.len());
// leaves were copied correctly
for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
assert_eq!(a, b);
}
let (root, node2, node3) = compute_internal_nodes();
assert_eq!(root, tree.nodes[1]);
assert_eq!(node2, tree.nodes[2]);
assert_eq!(node3, tree.nodes[3]);
assert_eq!(root, tree.root());
}
#[test]
fn get_leaf() {
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
// check depth 2
assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// check depth 1
let (_, node2, node3) = compute_internal_nodes();
assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
}
#[test]
fn get_path() {
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
let (_, node2, node3) = compute_internal_nodes();
// check depth 2
assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());
// check depth 1
assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
}
#[test]
fn update_leaf() {
let mut tree = super::MerkleTree::new(digests_to_words(&LEAVES8)).unwrap();
// update one leaf
let value = 3;
let new_node = int_to_leaf(9);
let mut expected_leaves = digests_to_words(&LEAVES8);
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
tree.update_leaf(value, new_node).unwrap();
assert_eq!(expected_tree.nodes, tree.nodes);
// update another leaf
let value = 6;
let new_node = int_to_leaf(10);
expected_leaves[value as usize] = new_node;
let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();
tree.update_leaf(value, new_node).unwrap();
assert_eq!(expected_tree.nodes, tree.nodes);
}
#[test]
fn nodes() -> Result<(), MerkleError> {
let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
let root = tree.root();
let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let expected = vec![
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
];
assert_eq!(nodes, expected);
Ok(())
}
proptest! {
#[test]
fn arbitrary_word_can_be_represented_as_digest(
a in prop::num::u64::ANY,
b in prop::num::u64::ANY,
c in prop::num::u64::ANY,
d in prop::num::u64::ANY,
) {
// this test will assert the memory equivalence between word and digest.
// it is used to safeguard the `[MerkleTee::update_leaf]` implementation
// that assumes this equivalence.
// build a word and copy it to another address as digest
let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
let digest = RpoDigest::from(word);
// assert the addresses are different
let word_ptr = word.as_ptr() as *const u8;
let digest_ptr = digest.as_ptr() as *const u8;
assert_ne!(word_ptr, digest_ptr);
// compare the bytes representation
let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
assert_eq!(word_bytes, digest_bytes);
}
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
let node2 =
Rpo256::hash_elements(&[Word::from(LEAVES4[0]), Word::from(LEAVES4[1])].concat());
let node3 =
Rpo256::hash_elements(&[Word::from(LEAVES4[2]), Word::from(LEAVES4[3])].concat());
let root = Rpo256::merge(&[node2, node3]);
(root, node2, node3)
}
}

View file

@ -1,46 +0,0 @@
/// Iterate over the bits of a `usize` and yields the bit positions for the true bits.
pub struct TrueBitPositionIterator {
value: usize,
}
impl TrueBitPositionIterator {
pub fn new(value: usize) -> TrueBitPositionIterator {
TrueBitPositionIterator { value }
}
}
impl Iterator for TrueBitPositionIterator {
type Item = u32;
fn next(&mut self) -> Option<<Self as Iterator>::Item> {
// trailing_zeros is computed with the intrinsic cttz. [Rust 1.67.0] x86 uses the `bsf`
// instruction. AArch64 uses the `rbit clz` instructions.
let zeros = self.value.trailing_zeros();
if zeros == usize::BITS {
None
} else {
let bit_position = zeros;
let mask = 1 << bit_position;
self.value ^= mask;
Some(bit_position)
}
}
}
impl DoubleEndedIterator for TrueBitPositionIterator {
fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
// trailing_zeros is computed with the intrinsic ctlz. [Rust 1.67.0] x86 uses the `bsr`
// instruction. AArch64 uses the `clz` instruction.
let zeros = self.value.leading_zeros();
if zeros == usize::BITS {
None
} else {
let bit_position = usize::BITS - zeros - 1;
let mask = 1 << bit_position;
self.value ^= mask;
Some(bit_position)
}
}
}

View file

@ -1,18 +0,0 @@
use alloc::vec::Vec;
use super::super::RpoDigest;
/// Container for the update data of a [super::PartialMmr]
#[derive(Debug)]
pub struct MmrDelta {
/// The new version of the [super::Mmr]
pub forest: usize,
/// Update data.
///
/// The data is packed as follows:
/// 1. All the elements needed to perform authentication path updates. These are the right
/// siblings required to perform tree merges on the [super::PartialMmr].
/// 2. The new peaks.
pub data: Vec<RpoDigest>,
}

View file

@ -1,27 +0,0 @@
use alloc::string::String;
use thiserror::Error;
use crate::merkle::MerkleError;
#[derive(Debug, Error)]
pub enum MmrError {
#[error("mmr does not contain position {0}")]
PositionNotFound(usize),
#[error("mmr peaks are invalid: {0}")]
InvalidPeaks(String),
#[error(
"mmr peak does not match the computed merkle root of the provided authentication path"
)]
PeakPathMismatch,
#[error("requested peak index is {peak_idx} but the number of peaks is {peaks_len}")]
PeakOutOfBounds { peak_idx: usize, peaks_len: usize },
#[error("invalid mmr update")]
InvalidUpdate,
#[error("mmr does not contain a peak with depth {0}")]
UnknownPeak(u8),
#[error("invalid merkle path")]
InvalidMerklePath(#[source] MerkleError),
#[error("merkle root computation failed")]
MerkleRootComputationFailed(#[source] MerkleError),
}

View file

@ -1,447 +0,0 @@
//! A fully materialized Merkle mountain range (MMR).
//!
//! A MMR is a forest structure, i.e. it is an ordered set of disjoint rooted trees. The trees are
//! ordered by size, from the most to least number of leaves. Every tree is a perfect binary tree,
//! meaning a tree has all its leaves at the same depth, and every inner node has a branch-factor
//! of 2 with both children set.
//!
//! Additionally the structure only supports adding leaves to the right-most tree, the one with the
//! least number of leaves. The structure preserves the invariant that each tree has different
//! depths, i.e. as part of adding a new element to the forest the trees with same depth are
//! merged, creating a new tree with depth d+1, this process is continued until the property is
//! reestablished.
use alloc::vec::Vec;
use super::{
super::{InnerNodeInfo, MerklePath},
bit::TrueBitPositionIterator,
leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
RpoDigest,
};
// MMR
// ===============================================================================================
/// A fully materialized Merkle Mountain Range, with every tree in the forest and all their
/// elements.
///
/// Since this is a full representation of the MMR, elements are never removed and the MMR will
/// grow roughly `O(2n)` in number of leaf elements.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Mmr {
/// Refer to the `forest` method documentation for details of the semantics of this value.
pub(super) forest: usize,
/// Contains every element of the forest.
///
/// The trees are in postorder sequential representation. This representation allows for all
/// the elements of every tree in the forest to be stored in the same sequential buffer. It
/// also means new elements can be added to the forest, and merging of trees is very cheap with
/// no need to copy elements.
pub(super) nodes: Vec<RpoDigest>,
}
impl Default for Mmr {
fn default() -> Self {
Self::new()
}
}
impl Mmr {
// CONSTRUCTORS
// ============================================================================================
/// Constructor for an empty `Mmr`.
pub fn new() -> Mmr {
Mmr { forest: 0, nodes: Vec::new() }
}
// ACCESSORS
// ============================================================================================
/// Returns the MMR forest representation.
///
/// The forest value has the following interpretations:
/// - its value is the number of elements in the forest
/// - bit count corresponds to the number of trees in the forest
/// - each true bit position determines the depth of a tree in the forest
pub const fn forest(&self) -> usize {
self.forest
}
// FUNCTIONALITY
// ============================================================================================
/// Returns an [MmrProof] for the leaf at the specified position.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if the specified leaf position is out of bounds for this MMR.
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
self.open_at(pos, self.forest)
}
/// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR
/// at the specified `forest`.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if:
/// - The specified leaf position is out of bounds for this MMR.
/// - The specified `forest` value is not valid for this MMR.
pub fn open_at(&self, pos: usize, forest: usize) -> Result<MmrProof, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::PositionNotFound(pos))?;
// isolate the trees before the target
let forest_before = forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);
// update the value position from global to the target tree
let relative_pos = pos - forest_before;
// collect the path and the final index of the target value
let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
Ok(MmrProof {
forest,
position: pos,
merkle_path: MerklePath::new(path),
})
}
/// Returns the leaf value at position `pos`.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
// isolate the trees before the target
let forest_before = self.forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);
// update the value position from global to the target tree
let relative_pos = pos - forest_before;
// collect the path and the final index of the target value
let (value, _) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
Ok(value)
}
/// Adds a new element to the MMR.
pub fn add(&mut self, el: RpoDigest) {
// Note: every node is also a tree of size 1, adding an element to the forest creates a new
// rooted-tree of size 1. This may temporarily break the invariant that every tree in the
// forest has different sizes, the loop below will eagerly merge trees of same size and
// restore the invariant.
self.nodes.push(el);
let mut left_offset = self.nodes.len().saturating_sub(2);
let mut right = el;
let mut left_tree = 1;
while self.forest & left_tree != 0 {
right = Rpo256::merge(&[self.nodes[left_offset], right]);
self.nodes.push(right);
left_offset = left_offset.saturating_sub(nodes_in_forest(left_tree));
left_tree <<= 1;
}
self.forest += 1;
}
/// Returns the current peaks of the MMR.
pub fn peaks(&self) -> MmrPeaks {
self.peaks_at(self.forest).expect("failed to get peaks at current forest")
}
/// Returns the peaks of the MMR at the state specified by `forest`.
///
/// # Errors
/// Returns an error if the specified `forest` value is not valid for this MMR.
pub fn peaks_at(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
if forest > self.forest {
return Err(MmrError::InvalidPeaks(format!(
"requested forest {forest} exceeds current forest {}",
self.forest
)));
}
let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
.rev()
.map(|bit| nodes_in_forest(1 << bit))
.scan(0, |offset, el| {
*offset += el;
Some(*offset)
})
.map(|offset| self.nodes[offset - 1])
.collect();
// Safety: the invariant is maintained by the [Mmr]
let peaks = MmrPeaks::new(forest, peaks).unwrap();
Ok(peaks)
}
/// Compute the required update to `original_forest`.
///
/// The result is a packed sequence of the authentication elements required to update the trees
/// that have been merged together, followed by the new peaks of the [Mmr].
pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
if to_forest > self.forest || from_forest > to_forest {
return Err(MmrError::InvalidPeaks(format!("to_forest {to_forest} exceeds the current forest {} or from_forest {from_forest} exceeds to_forest", self.forest)));
}
if from_forest == to_forest {
return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
}
let mut result = Vec::new();
// Find the largest tree in this [Mmr] which is new to `from_forest`.
let candidate_trees = to_forest ^ from_forest;
let mut new_high = 1 << candidate_trees.ilog2();
// Collect authentication nodes used for tree merges
// ----------------------------------------------------------------------------------------
// Find the trees from `from_forest` that have been merged into `new_high`.
let mut merges = from_forest & (new_high - 1);
// Find the peaks that are common to `from_forest` and this [Mmr]
let common_trees = from_forest ^ merges;
if merges != 0 {
// Skip the smallest trees unknown to `from_forest`.
let mut target = 1 << merges.trailing_zeros();
// Collect siblings required to computed the merged tree's peak
while target < new_high {
// Computes the offset to the smallest know peak
// - common_trees: peaks unchanged in the current update, target comes after these.
// - merges: peaks that have not been merged so far, target comes after these.
// - target: tree from which to load the sibling. On the first iteration this is a
// value known by the partial mmr, on subsequent iterations this value is to be
// computed from the known peaks and provided authentication nodes.
let known = nodes_in_forest(common_trees | merges | target);
let sibling = nodes_in_forest(target);
result.push(self.nodes[known + sibling - 1]);
// Update the target and account for tree merges
target <<= 1;
while merges & target != 0 {
target <<= 1;
}
// Remove the merges done so far
merges ^= merges & (target - 1);
}
} else {
// The new high tree may not be the result of any merges, if it is smaller than all the
// trees of `from_forest`.
new_high = 0;
}
// Collect the new [Mmr] peaks
// ----------------------------------------------------------------------------------------
let mut new_peaks = to_forest ^ common_trees ^ new_high;
let old_peaks = to_forest ^ new_peaks;
let mut offset = nodes_in_forest(old_peaks);
while new_peaks != 0 {
let target = 1 << new_peaks.ilog2();
offset += nodes_in_forest(target);
result.push(self.nodes[offset - 1]);
new_peaks ^= target;
}
Ok(MmrDelta { forest: to_forest, data: result })
}
/// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
pub fn inner_nodes(&self) -> MmrNodes {
MmrNodes {
mmr: self,
forest: 0,
last_right: 0,
index: 0,
}
}
// UTILITIES
// ============================================================================================
/// Internal function used to collect the Merkle path of a value.
///
/// The arguments are relative to the target tree. To compute the opening of the second leaf
/// for a tree with depth 2 in the forest `0b110`:
///
/// - `tree_bit`: Depth of the target tree, e.g. 2 for the smallest tree.
/// - `relative_pos`: 0-indexed leaf position in the target tree, e.g. 1 for the second leaf.
/// - `index_offset`: Node count prior to the target tree, e.g. 7 for the tree of depth 3.
fn collect_merkle_path_and_value(
&self,
tree_bit: u32,
relative_pos: usize,
index_offset: usize,
) -> (RpoDigest, Vec<RpoDigest>) {
// see documentation of `leaf_to_corresponding_tree` for details
let tree_depth = (tree_bit + 1) as usize;
let mut path = Vec::with_capacity(tree_depth);
// The tree walk below goes from the root to the leaf, compute the root index to start
let mut forest_target = 1usize << tree_bit;
let mut index = nodes_in_forest(forest_target) - 1;
// Loop until the leaf is reached
while forest_target > 1 {
// Update the depth of the tree to correspond to a subtree
forest_target >>= 1;
// compute the indices of the right and left subtrees based on the post-order
let right_offset = index - 1;
let left_offset = right_offset - nodes_in_forest(forest_target);
let left_or_right = relative_pos & forest_target;
let sibling = if left_or_right != 0 {
// going down the right subtree, the right child becomes the new root
index = right_offset;
// and the left child is the authentication
self.nodes[index_offset + left_offset]
} else {
index = left_offset;
self.nodes[index_offset + right_offset]
};
path.push(sibling);
}
debug_assert!(path.len() == tree_depth - 1);
// the rest of the codebase has the elements going from leaf to root, adjust it here for
// easy of use/consistency sake
path.reverse();
let value = self.nodes[index_offset + index];
(value, path)
}
}
// CONVERSIONS
// ================================================================================================
impl<T> From<T> for Mmr
where
T: IntoIterator<Item = RpoDigest>,
{
fn from(values: T) -> Self {
let mut mmr = Mmr::new();
for v in values {
mmr.add(v)
}
mmr
}
}
// ITERATOR
// ===============================================================================================
/// Yields inner nodes of the [Mmr].
pub struct MmrNodes<'a> {
/// [Mmr] being yielded, when its `forest` value is matched, the iterations is finished.
mmr: &'a Mmr,
/// Keeps track of the left nodes yielded so far waiting for a right pair, this matches the
/// semantics of the [Mmr]'s forest attribute, since that too works as a buffer of left nodes
/// waiting for a pair to be hashed together.
forest: usize,
/// Keeps track of the last right node yielded, after this value is set, the next iteration
/// will be its parent with its corresponding left node that has been yield already.
last_right: usize,
/// The current index in the `nodes` vector.
index: usize,
}
impl Iterator for MmrNodes<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
debug_assert!(self.last_right.count_ones() <= 1, "last_right tracks zero or one element");
// only parent nodes are emitted, remove the single node tree from the forest
let target = self.mmr.forest & (usize::MAX << 1);
if self.forest < target {
if self.last_right == 0 {
// yield the left leaf
debug_assert!(self.last_right == 0, "left must be before right");
self.forest |= 1;
self.index += 1;
// yield the right leaf
debug_assert!((self.forest & 1) == 1, "right must be after left");
self.last_right |= 1;
self.index += 1;
};
debug_assert!(
self.forest & self.last_right != 0,
"parent requires both a left and right",
);
// compute the number of nodes in the right tree, this is the offset to the
// previous left parent
let right_nodes = nodes_in_forest(self.last_right);
// the next parent position is one above the position of the pair
let parent = self.last_right << 1;
// the left node has been paired and the current parent yielded, removed it from the
// forest
self.forest ^= self.last_right;
if self.forest & parent == 0 {
// this iteration yielded the left parent node
debug_assert!(self.forest & 1 == 0, "next iteration yields a left leaf");
self.last_right = 0;
self.forest ^= parent;
} else {
// the left node of the parent level has been yielded already, this iteration
// was the right parent. Next iteration yields their parent.
self.last_right = parent;
}
// yields a parent
let value = self.mmr.nodes[self.index];
let right = self.mmr.nodes[self.index - 1];
let left = self.mmr.nodes[self.index - 1 - right_nodes];
self.index += 1;
let node = InnerNodeInfo { value, left, right };
Some(node)
} else {
None
}
}
}
// UTILITIES
// ===============================================================================================
/// Return a bitmask for the bits including and above the given position.
pub(crate) const fn high_bitmask(bit: u32) -> usize {
if bit > usize::BITS - 1 {
0
} else {
usize::MAX << bit
}
}

View file

@ -1,191 +0,0 @@
//! Index for nodes of a binary tree based on an in-order tree walk.
//!
//! In-order walks have the parent node index split its left and right subtrees. All the left
//! children have indexes lower than the parent, meanwhile all the right subtree higher indexes.
//! This property makes it is easy to compute changes to the index by adding or subtracting the
//! leaves count.
use core::num::NonZeroUsize;
use winter_utils::{Deserializable, Serializable};
// IN-ORDER INDEX
// ================================================================================================
/// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct InOrderIndex {
idx: usize,
}
impl InOrderIndex {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [InOrderIndex] instantiated from the provided value.
pub fn new(idx: NonZeroUsize) -> InOrderIndex {
InOrderIndex { idx: idx.get() }
}
/// Return a new [InOrderIndex] instantiated from the specified leaf position.
///
/// # Panics:
/// If `leaf` is higher than or equal to `usize::MAX / 2`.
pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
// Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
// implementation only works 1-indexed counting.
let pos = leaf + 1;
InOrderIndex { idx: pos * 2 - 1 }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// True if the index is pointing at a leaf.
///
/// Every odd number represents a leaf.
pub fn is_leaf(&self) -> bool {
self.idx & 1 == 1
}
/// Returns true if this note is a left child of its parent.
pub fn is_left_child(&self) -> bool {
self.parent().left_child() == *self
}
/// Returns the level of the index.
///
/// Starts at level zero for leaves and increases by one for each parent.
pub fn level(&self) -> u32 {
self.idx.trailing_zeros()
}
/// Returns the index of the left child.
///
/// # Panics:
/// If the index corresponds to a leaf.
pub fn left_child(&self) -> InOrderIndex {
// The left child is itself a parent, with an index that splits its left/right subtrees. To
// go from the parent index to its left child, it is only necessary to subtract the count
// of elements on the child's right subtree + 1.
let els = 1 << (self.level() - 1);
InOrderIndex { idx: self.idx - els }
}
/// Returns the index of the right child.
///
/// # Panics:
/// If the index corresponds to a leaf.
pub fn right_child(&self) -> InOrderIndex {
// To compute the index of the parent of the right subtree it is sufficient to add the size
// of its left subtree + 1.
let els = 1 << (self.level() - 1);
InOrderIndex { idx: self.idx + els }
}
/// Returns the index of the parent node.
pub fn parent(&self) -> InOrderIndex {
// If the current index corresponds to a node in a left tree, to go up a level it is
// required to add the number of nodes of the right sibling, analogously if the node is a
// right child, going up requires subtracting the number of nodes in its left subtree.
//
// Both of the above operations can be performed by bitwise manipulation. Below the mask
// sets the number of trailing zeros to be equal the new level of the index, and the bit
// marks the parent.
let target = self.level() + 1;
let bit = 1 << target;
let mask = bit - 1;
let idx = self.idx ^ (self.idx & mask);
InOrderIndex { idx: idx | bit }
}
/// Returns the index of the sibling node.
pub fn sibling(&self) -> InOrderIndex {
let parent = self.parent();
if *self > parent {
parent.left_child()
} else {
parent.right_child()
}
}
/// Returns the inner value of this [InOrderIndex].
pub fn inner(&self) -> u64 {
self.idx as u64
}
}
impl Serializable for InOrderIndex {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
target.write_usize(self.idx);
}
}
impl Deserializable for InOrderIndex {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let idx = source.read_usize()?;
Ok(InOrderIndex { idx })
}
}
// CONVERSIONS FROM IN-ORDER INDEX
// ------------------------------------------------------------------------------------------------
impl From<InOrderIndex> for u64 {
fn from(index: InOrderIndex) -> Self {
index.idx as u64
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod test {
use proptest::prelude::*;
use winter_utils::{Deserializable, Serializable};
use super::InOrderIndex;
proptest! {
#[test]
fn proptest_inorder_index_random(count in 1..1000usize) {
let left_pos = count * 2;
let right_pos = count * 2 + 1;
let left = InOrderIndex::from_leaf_pos(left_pos);
let right = InOrderIndex::from_leaf_pos(right_pos);
assert!(left.is_leaf());
assert!(right.is_leaf());
assert_eq!(left.parent(), right.parent());
assert_eq!(left.parent().right_child(), right);
assert_eq!(left, right.parent().left_child());
assert_eq!(left.sibling(), right);
assert_eq!(left, right.sibling());
}
}
#[test]
fn test_inorder_index_basic() {
let left = InOrderIndex::from_leaf_pos(0);
let right = InOrderIndex::from_leaf_pos(1);
assert!(left.is_leaf());
assert!(right.is_leaf());
assert_eq!(left.parent(), right.parent());
assert_eq!(left.parent().right_child(), right);
assert_eq!(left, right.parent().left_child());
assert_eq!(left.sibling(), right);
assert_eq!(left, right.sibling());
}
#[test]
fn test_inorder_index_serialization() {
let index = InOrderIndex::from_leaf_pos(5);
let bytes = index.to_bytes();
let index2 = InOrderIndex::read_from_bytes(&bytes).unwrap();
assert_eq!(index, index2);
}
}

View file

@ -1,67 +0,0 @@
mod bit;
mod delta;
mod error;
mod full;
mod inorder;
mod partial;
mod peaks;
mod proof;
#[cfg(test)]
mod tests;
// REEXPORTS
// ================================================================================================
pub use delta::MmrDelta;
pub use error::MmrError;
pub use full::Mmr;
pub use inorder::InOrderIndex;
pub use partial::PartialMmr;
pub use peaks::MmrPeaks;
pub use proof::MmrProof;
use super::{Felt, Rpo256, RpoDigest, Word};
// UTILITIES
// ===============================================================================================
/// Given a 0-indexed leaf position and the current forest, return the tree number responsible for
/// the position.
///
/// Note:
/// The result is a tree position `p`, it has the following interpretations. $p+1$ is the depth of
/// the tree. Because the root element is not part of the proof, $p$ is the length of the
/// authentication path. $2^p$ is equal to the number of leaves in this particular tree. and
/// $2^(p+1)-1$ corresponds to size of the tree.
const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
if pos >= forest {
None
} else {
// - each bit in the forest is a unique tree and the bit position its power-of-two size
// - each tree owns a consecutive range of positions equal to its size from left-to-right
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
// is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
// `k_1` is the second highest bit, so on.
// - this means the highest bits work as a category marker, and the position is owned by the
// first tree which doesn't share a high bit with the position
let before = forest & pos;
let after = forest ^ before;
let tree = after.ilog2();
Some(tree)
}
}
/// Return the total number of nodes of a given forest
///
/// Panics:
///
/// This will panic if the forest has size greater than `usize::MAX / 2`
const fn nodes_in_forest(forest: usize) -> usize {
// - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
// - the forest represents the sum of $2^k$ so a single multiplication is necessary
// - the number of `-1` is the same as the number of trees, which is the same as the number
// bits set
let tree_count = forest.count_ones() as usize;
forest * 2 - tree_count
}

View file

@ -1,952 +0,0 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use winter_utils::{Deserializable, Serializable};
use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
use crate::merkle::{
mmr::{leaf_to_corresponding_tree, nodes_in_forest},
InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
};
// TYPE ALIASES
// ================================================================================================
type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;
// PARTIAL MERKLE MOUNTAIN RANGE
// ================================================================================================
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
/// authentication paths for a subset of the elements in a full MMR.
///
/// This structure store only the authentication path for a value, the value itself is stored
/// separately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMmr {
/// The version of the MMR.
///
/// This value serves the following purposes:
///
/// - The forest is a counter for the total number of elements in the MMR.
/// - Since the MMR is an append-only structure, every change to it causes a change to the
/// `forest`, so this value has a dual purpose as a version tag.
/// - The bits in the forest also corresponds to the count and size of every perfect binary
/// tree that composes the MMR structure, which server to compute indexes and perform
/// validation.
pub(crate) forest: usize,
/// The MMR peaks.
///
/// The peaks are used for two reasons:
///
/// 1. It authenticates the addition of an element to the [PartialMmr], ensuring only valid
/// elements are tracked.
/// 2. During a MMR update peaks can be merged by hashing the left and right hand sides. The
/// peaks are used as the left hand.
///
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
pub(crate) peaks: Vec<RpoDigest>,
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
///
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
/// construct their authentication paths. This property is used to detect when elements can be
/// safely removed, because they are no longer required to authenticate any element in the
/// [PartialMmr].
///
/// The elements in the MMR are referenced using a in-order tree index. This indexing scheme
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
/// trees in the MMR can be represented without rewrites of the indexes.
pub(crate) nodes: NodeMap,
/// Flag indicating if the odd element should be tracked.
///
/// This flag is necessary because the sibling of the odd doesn't exist yet, so it can not be
/// added into `nodes` to signal the value is being tracked.
pub(crate) track_latest: bool,
}
impl PartialMmr {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [PartialMmr] instantiated from the specified peaks.
pub fn from_peaks(peaks: MmrPeaks) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
let nodes = BTreeMap::new();
let track_latest = false;
Self { forest, peaks, nodes, track_latest }
}
/// Returns a new [PartialMmr] instantiated from the specified components.
///
/// This constructor does not check the consistency between peaks and nodes. If the specified
/// peaks are nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
Self { forest, peaks, nodes, track_latest }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the current `forest` of this [PartialMmr].
///
/// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
/// underlying MMR.
pub fn forest(&self) -> usize {
self.forest
}
/// Returns the number of leaves in the underlying MMR for this [PartialMmr].
pub fn num_leaves(&self) -> usize {
self.forest
}
/// Returns the peaks of the MMR for this [PartialMmr].
pub fn peaks(&self) -> MmrPeaks {
// expect() is OK here because the constructor ensures that MMR peaks can be constructed
// correctly
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
}
/// Returns true if this partial MMR tracks an authentication path for the leaf at the
/// specified position.
pub fn is_tracked(&self, pos: usize) -> bool {
if pos >= self.forest {
return false;
} else if pos == self.forest - 1 && self.forest & 1 != 0 {
// if the number of leaves in the MMR is odd and the position is for the last leaf
// whether the leaf is tracked is defined by the `track_latest` flag
return self.track_latest;
}
let leaf_index = InOrderIndex::from_leaf_pos(pos);
self.is_tracked_node(&leaf_index)
}
/// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
/// partial MMR does not track an authentication paths for the specified leaf.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if the specified position is greater-or-equal than the number of leaves
/// in the underlying MMR.
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
let depth = tree_bit as usize;
let mut nodes = Vec::with_capacity(depth);
let mut idx = InOrderIndex::from_leaf_pos(pos);
while let Some(node) = self.nodes.get(&idx.sibling()) {
nodes.push(*node);
idx = idx.parent();
}
// If there are nodes then the path must be complete, otherwise it is a bug
debug_assert!(nodes.is_empty() || nodes.len() == depth);
if nodes.len() != depth {
// The requested `pos` is not being tracked.
Ok(None)
} else {
Ok(Some(MmrProof {
forest: self.forest,
position: pos,
merkle_path: MerklePath::new(nodes),
}))
}
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator nodes of all authentication paths of this [PartialMmr].
pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
self.nodes.iter()
}
/// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
///
/// The order of iteration is not defined. If a leaf is not presented in this partial MMR it
/// is silently ignored.
pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
&'a self,
mut leaves: I,
) -> impl Iterator<Item = InnerNodeInfo> + 'a {
let stack = if let Some((pos, leaf)) = leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
vec![(idx, leaf)]
} else {
Vec::new()
};
InnerNodeIterator {
nodes: &self.nodes,
leaves,
stack,
seen_nodes: BTreeSet::new(),
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Adds a new peak and optionally track it. Returns a vector of the authentication nodes
/// inserted into this [PartialMmr] as a result of this operation.
///
/// When `track` is `true` the new leaf is tracked.
pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
self.forest += 1;
let merges = self.forest.trailing_zeros() as usize;
let mut new_nodes = Vec::with_capacity(merges);
let peak = if merges == 0 {
self.track_latest = track;
leaf
} else {
let mut track_right = track;
let mut track_left = self.track_latest;
let mut right = leaf;
let mut right_idx = forest_to_rightmost_index(self.forest);
for _ in 0..merges {
let left = self.peaks.pop().expect("Missing peak");
let left_idx = right_idx.sibling();
if track_right {
let old = self.nodes.insert(left_idx, left);
new_nodes.push((left_idx, left));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
left_idx,
old
);
};
if track_left {
let old = self.nodes.insert(right_idx, right);
new_nodes.push((right_idx, right));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
right_idx,
old
);
};
// Update state for the next iteration.
// --------------------------------------------------------------------------------
// This layer is merged, go up one layer.
right_idx = right_idx.parent();
// Merge the current layer. The result is either the right element of the next
// merge, or a new peak.
right = Rpo256::merge(&[left, right]);
// This iteration merged the left and right nodes, the new value is always used as
// the next iteration's right node. Therefore the tracking flags of this iteration
// have to be merged into the right side only.
track_right = track_right || track_left;
// On the next iteration, a peak will be merged. If any of its children are tracked,
// then we have to track the left side
track_left = self.is_tracked_node(&right_idx.sibling());
}
right
};
self.peaks.push(peak);
new_nodes
}
/// Adds the authentication path represented by [MerklePath] if it is valid.
///
/// The `leaf_pos` refers to the global position of the leaf in the MMR, these are 0-indexed
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR,
/// this value corresponds to the values used in the MMR structure.
///
/// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path for
/// that element up to its corresponding Mmr peak. The `leaf` is only used to compute the root
/// from the authentication path to valid the data, only the authentication data is saved in
/// the structure. If the value is required it should be stored out-of-band.
pub fn track(
&mut self,
leaf_pos: usize,
leaf: RpoDigest,
path: &MerklePath,
) -> Result<(), MmrError> {
// Checks there is a tree with same depth as the authentication path, if not the path is
// invalid.
let tree = 1 << path.depth();
if tree & self.forest == 0 {
return Err(MmrError::UnknownPeak(path.depth()));
};
if leaf_pos + 1 == self.forest
&& path.depth() == 0
&& self.peaks.last().is_some_and(|v| *v == leaf)
{
self.track_latest = true;
return Ok(());
}
// ignore the trees smaller than the target (these elements are position after the current
// target and don't affect the target leaf_pos)
let target_forest = self.forest ^ (self.forest & (tree - 1));
let peak_pos = (target_forest.count_ones() - 1) as usize;
// translate from mmr leaf_pos to merkle path
let path_idx = leaf_pos - (target_forest ^ tree);
// Compute the root of the authentication path, and check it matches the current version of
// the PartialMmr.
let computed = path
.compute_root(path_idx as u64, leaf)
.map_err(MmrError::MerkleRootComputationFailed)?;
if self.peaks[peak_pos] != computed {
return Err(MmrError::PeakPathMismatch);
}
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
for leaf in path.nodes() {
self.nodes.insert(idx.sibling(), *leaf);
idx = idx.parent();
}
Ok(())
}
/// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
///
/// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
pub fn untrack(&mut self, leaf_pos: usize) {
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
self.nodes.remove(&idx.sibling());
// `idx` represent the element that can be computed by the authentication path, because
// these elements can be computed they are not saved for the authentication of the current
// target. In other words, if the idx is present it was added for the authentication of
// another element, and no more elements should be removed otherwise it would remove that
// element's authentication data.
while !self.nodes.contains_key(&idx) {
idx = idx.parent();
self.nodes.remove(&idx.sibling());
}
}
/// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
/// inserted into the partial MMR.
pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
if delta.forest < self.forest {
return Err(MmrError::InvalidPeaks(format!(
"forest of mmr delta {} is less than current forest {}",
delta.forest, self.forest
)));
}
let mut inserted_nodes = Vec::new();
if delta.forest == self.forest {
if !delta.data.is_empty() {
return Err(MmrError::InvalidUpdate);
}
return Ok(inserted_nodes);
}
// find the tree merges
let changes = self.forest ^ delta.forest;
let largest = 1 << changes.ilog2();
let merges = self.forest & (largest - 1);
debug_assert!(
!self.track_latest || (merges & 1) == 1,
"if there is an odd element, a merge is required"
);
// count the number elements needed to produce largest from the current state
let (merge_count, new_peaks) = if merges != 0 {
let depth = largest.trailing_zeros();
let skipped = merges.trailing_zeros();
let computed = merges.count_ones() - 1;
let merge_count = depth - skipped - computed;
let new_peaks = delta.forest & (largest - 1);
(merge_count, new_peaks)
} else {
(0, changes)
};
// verify the delta size
if (delta.data.len() as u32) != merge_count + new_peaks.count_ones() {
return Err(MmrError::InvalidUpdate);
}
// keeps track of how many data elements from the update have been consumed
let mut update_count = 0;
if merges != 0 {
// starts at the smallest peak and follows the merged peaks
let mut peak_idx = forest_to_root_index(self.forest);
// match order of the update data while applying it
self.peaks.reverse();
// set to true when the data is needed for authentication paths updates
let mut track = self.track_latest;
self.track_latest = false;
let mut peak_count = 0;
let mut target = 1 << merges.trailing_zeros();
let mut new = delta.data[0];
update_count += 1;
while target < largest {
// check if either the left or right subtrees have saved for authentication paths.
// If so, turn tracking on to update those paths.
if target != 1 && !track {
track = self.is_tracked_node(&peak_idx);
}
// update data only contains the nodes from the right subtrees, left nodes are
// either previously known peaks or computed values
let (left, right) = if target & merges != 0 {
let peak = self.peaks[peak_count];
let sibling_idx = peak_idx.sibling();
// if the sibling peak is tracked, add this peaks to the set of
// authentication nodes
if self.is_tracked_node(&sibling_idx) {
self.nodes.insert(peak_idx, new);
inserted_nodes.push((peak_idx, new));
}
peak_count += 1;
(peak, new)
} else {
let update = delta.data[update_count];
update_count += 1;
(new, update)
};
if track {
let sibling_idx = peak_idx.sibling();
if peak_idx.is_left_child() {
self.nodes.insert(sibling_idx, right);
inserted_nodes.push((sibling_idx, right));
} else {
self.nodes.insert(sibling_idx, left);
inserted_nodes.push((sibling_idx, left));
}
}
peak_idx = peak_idx.parent();
new = Rpo256::merge(&[left, right]);
target <<= 1;
}
debug_assert!(peak_count == (merges.count_ones() as usize));
// restore the peaks order
self.peaks.reverse();
// remove the merged peaks
self.peaks.truncate(self.peaks.len() - peak_count);
// add the newly computed peak, the result of the merges
self.peaks.push(new);
}
// The rest of the update data is composed of peaks. None of these elements can contain
// tracked elements because the peaks were unknown, and it is not possible to add elements
// for tacking without authenticating it to a peak.
self.peaks.extend_from_slice(&delta.data[update_count..]);
self.forest = delta.forest;
debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));
Ok(inserted_nodes)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Returns true if this [PartialMmr] tracks authentication path for the node at the specified
/// index.
fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
if node_index.is_leaf() {
self.nodes.contains_key(&node_index.sibling())
} else {
let left_child = node_index.left_child();
let right_child = node_index.right_child();
self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
}
}
}
// CONVERSIONS
// ================================================================================================
impl From<MmrPeaks> for PartialMmr {
fn from(peaks: MmrPeaks) -> Self {
Self::from_peaks(peaks)
}
}
impl From<PartialMmr> for MmrPeaks {
fn from(partial_mmr: PartialMmr) -> Self {
// Safety: the [PartialMmr] maintains the constraints the number of true bits in the forest
// matches the number of peaks, as required by the [MmrPeaks]
MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
}
}
impl From<&MmrPeaks> for PartialMmr {
fn from(peaks: &MmrPeaks) -> Self {
Self::from_peaks(peaks.clone())
}
}
impl From<&PartialMmr> for MmrPeaks {
fn from(partial_mmr: &PartialMmr) -> Self {
// Safety: the [PartialMmr] maintains the constraints the number of true bits in the forest
// matches the number of peaks, as required by the [MmrPeaks]
MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
}
}
// ITERATORS
// ================================================================================================
/// An iterator over every inner node of the [PartialMmr].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
nodes: &'a NodeMap,
leaves: I,
stack: Vec<(InOrderIndex, RpoDigest)>,
seen_nodes: BTreeSet<InOrderIndex>,
}
impl<I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'_, I> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
while let Some((idx, node)) = self.stack.pop() {
let parent_idx = idx.parent();
let new_node = self.seen_nodes.insert(parent_idx);
// if we haven't seen this node's parent before, and the node has a sibling, return
// the inner node defined by the parent of this node, and move up the branch
if new_node {
if let Some(sibling) = self.nodes.get(&idx.sibling()) {
let (left, right) = if parent_idx.left_child() == idx {
(node, *sibling)
} else {
(*sibling, node)
};
let parent = Rpo256::merge(&[left, right]);
let inner_node = InnerNodeInfo { value: parent, left, right };
self.stack.push((parent_idx, parent));
return Some(inner_node);
}
}
// the previous leaf has been processed, try to process the next leaf
if let Some((pos, leaf)) = self.leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
self.stack.push((idx, leaf));
}
}
None
}
}
impl Serializable for PartialMmr {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.forest.write_into(target);
self.peaks.write_into(target);
self.nodes.write_into(target);
target.write_bool(self.track_latest);
}
}
impl Deserializable for PartialMmr {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let forest = usize::read_from(source)?;
let peaks = Vec::<RpoDigest>::read_from(source)?;
let nodes = NodeMap::read_from(source)?;
let track_latest = source.read_bool()?;
Ok(Self { forest, peaks, nodes, track_latest })
}
}
// UTILS
// ================================================================================================
/// Given the description of a `forest`, returns the index of the root element of the smallest tree
/// in it.
fn forest_to_root_index(forest: usize) -> InOrderIndex {
// Count total size of all trees in the forest.
let nodes = nodes_in_forest(forest);
// Add the count for the parent nodes that separate each tree. These are allocated but
// currently empty, and correspond to the nodes that will be used once the trees are merged.
let open_trees = (forest.count_ones() - 1) as usize;
// Remove the count of the right subtree of the target tree, target tree root index comes
// before the subtree for the in-order tree walk.
let right_subtree_count = ((1u32 << forest.trailing_zeros()) - 1) as usize;
let idx = nodes + open_trees - right_subtree_count;
InOrderIndex::new(idx.try_into().unwrap())
}
/// Given the description of a `forest`, returns the index of the right most element.
fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
// Count total size of all trees in the forest.
let nodes = nodes_in_forest(forest);
// Add the count for the parent nodes that separate each tree. These are allocated but
// currently empty, and correspond to the nodes that will be used once the trees are merged.
let open_trees = (forest.count_ones() - 1) as usize;
let idx = nodes + open_trees;
InOrderIndex::new(idx.try_into().unwrap())
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::{collections::BTreeSet, vec::Vec};
use winter_utils::{Deserializable, Serializable};
use super::{
forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
RpoDigest,
};
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
const LEAVES: [RpoDigest; 7] = [
int_to_node(0),
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
];
#[test]
fn test_forest_to_root_index() {
fn idx(pos: usize) -> InOrderIndex {
InOrderIndex::new(pos.try_into().unwrap())
}
// When there is a single tree in the forest, the index is equivalent to the number of
// leaves in that tree, which is `2^n`.
assert_eq!(forest_to_root_index(0b0001), idx(1));
assert_eq!(forest_to_root_index(0b0010), idx(2));
assert_eq!(forest_to_root_index(0b0100), idx(4));
assert_eq!(forest_to_root_index(0b1000), idx(8));
assert_eq!(forest_to_root_index(0b0011), idx(5));
assert_eq!(forest_to_root_index(0b0101), idx(9));
assert_eq!(forest_to_root_index(0b1001), idx(17));
assert_eq!(forest_to_root_index(0b0111), idx(13));
assert_eq!(forest_to_root_index(0b1011), idx(21));
assert_eq!(forest_to_root_index(0b1111), idx(29));
assert_eq!(forest_to_root_index(0b0110), idx(10));
assert_eq!(forest_to_root_index(0b1010), idx(18));
assert_eq!(forest_to_root_index(0b1100), idx(20));
assert_eq!(forest_to_root_index(0b1110), idx(26));
}
#[test]
fn test_forest_to_rightmost_index() {
fn idx(pos: usize) -> InOrderIndex {
InOrderIndex::new(pos.try_into().unwrap())
}
for forest in 1..256 {
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
}
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
}
#[test]
fn test_partial_mmr_apply_delta() {
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
let mut mmr = Mmr::default();
(0..10).for_each(|i| mmr.add(int_to_node(i)));
let mut partial_mmr: PartialMmr = mmr.peaks().into();
// add authentication path for position 1 and 8
{
let node = mmr.get(1).unwrap();
let proof = mmr.open(1).unwrap();
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
}
{
let node = mmr.get(8).unwrap();
let proof = mmr.open(8).unwrap();
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
}
// add 2 more nodes into the MMR and validate apply_delta()
(10..12).for_each(|i| mmr.add(int_to_node(i)));
validate_apply_delta(&mmr, &mut partial_mmr);
// add 1 more node to the MMR, validate apply_delta() and start tracking the node
mmr.add(int_to_node(12));
validate_apply_delta(&mmr, &mut partial_mmr);
{
let node = mmr.get(12).unwrap();
let proof = mmr.open(12).unwrap();
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
assert!(partial_mmr.track_latest);
}
// by this point we are tracking authentication paths for positions: 1, 8, and 12
// add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
(13..16).for_each(|i| mmr.add(int_to_node(i)));
validate_apply_delta(&mmr, &mut partial_mmr);
}
fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
let tracked_leaves = partial
.nodes
.iter()
.filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
.collect::<Vec<_>>();
let nodes_before = partial.nodes.clone();
// compute and apply delta
let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
let nodes_delta = partial.apply(delta).unwrap();
// new peaks were computed correctly
assert_eq!(mmr.peaks(), partial.peaks());
let mut expected_nodes = nodes_before;
for (key, value) in nodes_delta {
// nodes should not be duplicated
assert!(expected_nodes.insert(key, value).is_none());
}
// new nodes should be a combination of original nodes and delta
assert_eq!(expected_nodes, partial.nodes);
// make sure tracked leaves open to the same proofs as in the underlying MMR
for index in tracked_leaves {
let index_value: u64 = index.into();
let pos = index_value / 2;
let proof1 = partial.open(pos as usize).unwrap().unwrap();
let proof2 = mmr.open(pos as usize).unwrap();
assert_eq!(proof1, proof2);
}
}
#[test]
fn test_partial_mmr_inner_nodes_iterator() {
// build the MMR
let mmr: Mmr = LEAVES.into();
let first_peak = mmr.peaks().peaks()[0];
// -- test single tree ----------------------------
// get path and node for position 1
let node1 = mmr.get(1).unwrap();
let proof1 = mmr.open(1).unwrap();
// create partial MMR and add authentication path to node at position 1
let mut partial_mmr: PartialMmr = mmr.peaks().into();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
// empty iterator should have no nodes
assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));
let index1 = NodeIndex::new(2, 1).unwrap();
let path1 = store.get_path(first_peak, index1).unwrap().path;
assert_eq!(path1, proof1.merkle_path);
// -- test no duplicates --------------------------
// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks().into();
let node0 = mmr.get(0).unwrap();
let proof0 = mmr.open(0).unwrap();
let node2 = mmr.get(2).unwrap();
let proof2 = mmr.open(2).unwrap();
partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();
// make sure there are no duplicates
let leaves = [(0, node0), (1, node1), (2, node2)];
let mut nodes = BTreeSet::new();
for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
assert!(nodes.insert(node.value));
}
// and also that the store is still be built correctly
store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));
let index0 = NodeIndex::new(2, 0).unwrap();
let index1 = NodeIndex::new(2, 1).unwrap();
let index2 = NodeIndex::new(2, 2).unwrap();
let path0 = store.get_path(first_peak, index0).unwrap().path;
let path1 = store.get_path(first_peak, index1).unwrap().path;
let path2 = store.get_path(first_peak, index2).unwrap().path;
assert_eq!(path0, proof0.merkle_path);
assert_eq!(path1, proof1.merkle_path);
assert_eq!(path2, proof2.merkle_path);
// -- test multiple trees -------------------------
// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks().into();
let node5 = mmr.get(5).unwrap();
let proof5 = mmr.open(5).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));
let index1 = NodeIndex::new(2, 1).unwrap();
let index5 = NodeIndex::new(1, 1).unwrap();
let second_peak = mmr.peaks().peaks()[1];
let path1 = store.get_path(first_peak, index1).unwrap().path;
let path5 = store.get_path(second_peak, index5).unwrap().path;
assert_eq!(path1, proof1.merkle_path);
assert_eq!(path5, proof5.merkle_path);
}
#[test]
fn test_partial_mmr_add_without_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for el in (0..256).map(int_to_node) {
mmr.add(el);
partial_mmr.add(el, false);
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
}
}
#[test]
fn test_partial_mmr_add_with_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for i in 0..256 {
let el = int_to_node(i);
mmr.add(el);
partial_mmr.add(el, true);
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
for pos in 0..i {
let mmr_proof = mmr.open(pos as usize).unwrap();
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
assert_eq!(mmr_proof, partialmmr_proof);
}
}
}
#[test]
fn test_partial_mmr_add_existing_track() {
let mut mmr = Mmr::from((0..7).map(int_to_node));
// derive a partial Mmr from it which tracks authentication path to leaf 5
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
let path_to_5 = mmr.open(5).unwrap().merkle_path;
let leaf_at_5 = mmr.get(5).unwrap();
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
// add a new leaf to both Mmr and partial Mmr
let leaf_at_7 = int_to_node(7);
mmr.add(leaf_at_7);
partial_mmr.add(leaf_at_7, false);
// the openings should be the same
assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
}
#[test]
fn test_partial_mmr_serialization() {
let mmr = Mmr::from((0..7).map(int_to_node));
let partial_mmr = PartialMmr::from_peaks(mmr.peaks());
let bytes = partial_mmr.to_bytes();
let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();
assert_eq!(partial_mmr, decoded);
}
}

View file

@ -1,162 +0,0 @@
use alloc::vec::Vec;
use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
// MMR PEAKS
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
/// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
/// happens because the number of peaks goes up-and-down as the structure is used causing
/// existing trees to be merged and new ones to be created. As an example, every time the MMR
/// has a power-of-two number of leaves there is a single peak.
///
/// Every tree in the MMR forest has a distinct power-of-two size, this means only the right-
/// most tree can have an odd number of elements (e.g. `1`). Additionally this means that the
/// bits in `num_leaves` conveniently encode the size of each individual tree.
///
/// Examples:
///
/// - With 5 leaves, the binary `0b101`. The number of set bits is equal the number of peaks,
/// in this case there are 2 peaks. The 0-indexed least-significant position of the bit
/// determines the number of elements of a tree, so the rightmost tree has `2**0` elements
/// and the left most has `2**2`.
/// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the leftmost tree has
/// `2**3=8` elements, and the right most has `2**2=4` elements.
num_leaves: usize,
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
///
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
peaks: Vec<RpoDigest>,
}
impl MmrPeaks {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
/// leaves in the underlying MMR.
///
/// # Errors
/// Returns an error if the number of leaves and the number of peaks are inconsistent.
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
if num_leaves.count_ones() as usize != peaks.len() {
return Err(MmrError::InvalidPeaks(format!(
"number of one bits in leaves is {} which does not equal peak length {}",
num_leaves.count_ones(),
peaks.len()
)));
}
Ok(Self { num_leaves, peaks })
}
// ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns a count of leaves in the underlying MMR.
pub fn num_leaves(&self) -> usize {
self.num_leaves
}
/// Returns the number of peaks of the underlying MMR.
pub fn num_peaks(&self) -> usize {
self.peaks.len()
}
/// Returns the list of peaks of the underlying MMR.
pub fn peaks(&self) -> &[RpoDigest] {
&self.peaks
}
/// Returns the peak by the provided index.
///
/// # Errors
/// Returns an error if the provided peak index is greater or equal to the current number of
/// peaks in the Mmr.
pub fn get_peak(&self, peak_idx: usize) -> Result<&RpoDigest, MmrError> {
self.peaks
.get(peak_idx)
.ok_or(MmrError::PeakOutOfBounds { peak_idx, peaks_len: self.peaks.len() })
}
/// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
/// the underlying MMR.
pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
(self.num_leaves, self.peaks)
}
/// Hashes the peaks.
///
/// The procedure will:
/// - Flatten and pad the peaks to a vector of Felts.
/// - Hash the vector of Felts.
pub fn hash_peaks(&self) -> RpoDigest {
Rpo256::hash_elements(&self.flatten_and_pad_peaks())
}
/// Verifies the Merkle opening proof.
///
/// # Errors
/// Returns an error if:
/// - provided opening proof is invalid.
/// - Mmr root value computed using the provided leaf value differs from the actual one.
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> Result<(), MmrError> {
let root = self.get_peak(opening.peak_index())?;
opening
.merkle_path
.verify(opening.relative_pos() as u64, value, root)
.map_err(MmrError::InvalidMerklePath)
}
/// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
///
/// The procedure will:
/// - Flatten the vector of Words into a vector of Felts.
/// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO
/// padding.
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of hashing.
pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
let num_peaks = self.peaks.len();
// To achieve the padding rules above we calculate the length of the final vector.
// This is calculated as the number of field elements. Each peak is 4 field elements.
// The length is calculated as follows:
// - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires 64
// field elements.
// - If there are more than 16 peaks and the number of peaks is odd, the data is padded to
// an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
// - If there are more than 16 peaks and the number of peaks is even, the data is not padded
// and as such requires `num_peaks * 4` field elements.
let len = if num_peaks < 16 {
64
} else if num_peaks % 2 == 1 {
(num_peaks + 1) * 4
} else {
num_peaks * 4
};
let mut elements = Vec::with_capacity(len);
elements.extend_from_slice(
&self
.peaks
.as_slice()
.iter()
.map(|digest| digest.into())
.collect::<Vec<Word>>()
.concat(),
);
elements.resize(len, ZERO);
elements
}
}
impl From<MmrPeaks> for Vec<RpoDigest> {
fn from(peaks: MmrPeaks) -> Self {
peaks.peaks
}
}

View file

@ -1,106 +0,0 @@
/// The representation of a single Merkle path.
use super::super::MerklePath;
use super::{full::high_bitmask, leaf_to_corresponding_tree};
// MMR PROOF
// ================================================================================================
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrProof {
/// The state of the MMR when the MmrProof was created.
pub forest: usize,
/// The position of the leaf value on this MmrProof.
pub position: usize,
/// The Merkle opening, starting from the value's sibling up to and excluding the root of the
/// responsible tree.
pub merkle_path: MerklePath,
}
impl MmrProof {
/// Converts the leaf global position into a local position that can be used to verify the
/// merkle_path.
pub fn relative_pos(&self) -> usize {
let tree_bit = leaf_to_corresponding_tree(self.position, self.forest)
.expect("position must be part of the forest");
let forest_before = self.forest & high_bitmask(tree_bit + 1);
self.position - forest_before
}
/// Returns index of the MMR peak against which the Merkle path in this proof can be verified.
pub fn peak_index(&self) -> usize {
let root = leaf_to_corresponding_tree(self.position, self.forest)
.expect("position must be part of the forest");
let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
(self.forest.count_ones() - num_smaller_peaks - 1) as usize
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{MerklePath, MmrProof};
#[test]
fn test_peak_index() {
// --- single peak forest ---------------------------------------------
let forest = 11;
// the first 4 leaves belong to peak 0
for position in 0..8 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// --- forest with non-consecutive peaks ------------------------------
let forest = 11;
// the first 8 leaves belong to peak 0
for position in 0..8 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 8..10 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 10);
assert_eq!(proof.peak_index(), 2);
// --- forest with consecutive peaks ----------------------------------
let forest = 7;
// the first 4 leaves belong to peak 0
for position in 0..4 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 4..6 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 6);
assert_eq!(proof.peak_index(), 2);
}
fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
MmrProof {
forest,
position,
merkle_path: MerklePath::default(),
}
}
}

View file

@ -1,890 +0,0 @@
use alloc::vec::Vec;
use super::{
super::{InnerNodeInfo, Rpo256, RpoDigest},
bit::TrueBitPositionIterator,
full::high_bitmask,
leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
};
use crate::{
merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
Felt, Word,
};
#[test]
fn test_position_equal_or_higher_than_leafs_is_never_contained() {
let empty_forest = 0;
for pos in 1..1024 {
// pos is index, 0 based
// tree is a length counter, 1 based
// so a valid pos is always smaller, not equal, to tree
assert_eq!(leaf_to_corresponding_tree(pos, pos), None);
assert_eq!(leaf_to_corresponding_tree(pos, pos - 1), None);
// and empty forest has no trees, so no position is valid
assert_eq!(leaf_to_corresponding_tree(pos, empty_forest), None);
}
}
#[test]
fn test_position_zero_is_always_contained_by_the_highest_tree() {
for leaves in 1..1024usize {
let tree = leaves.ilog2();
assert_eq!(leaf_to_corresponding_tree(0, leaves), Some(tree));
}
}
#[test]
fn test_leaf_to_corresponding_tree() {
assert_eq!(leaf_to_corresponding_tree(0, 0b0001), Some(0));
assert_eq!(leaf_to_corresponding_tree(0, 0b0010), Some(1));
assert_eq!(leaf_to_corresponding_tree(0, 0b0011), Some(1));
assert_eq!(leaf_to_corresponding_tree(0, 0b1011), Some(3));
// position one is always owned by the left-most tree
assert_eq!(leaf_to_corresponding_tree(1, 0b0010), Some(1));
assert_eq!(leaf_to_corresponding_tree(1, 0b0011), Some(1));
assert_eq!(leaf_to_corresponding_tree(1, 0b1011), Some(3));
// position two starts as its own root, and then it is merged with the left-most tree
assert_eq!(leaf_to_corresponding_tree(2, 0b0011), Some(0));
assert_eq!(leaf_to_corresponding_tree(2, 0b0100), Some(2));
assert_eq!(leaf_to_corresponding_tree(2, 0b1011), Some(3));
// position tree is merged on the left-most tree
assert_eq!(leaf_to_corresponding_tree(3, 0b0011), None);
assert_eq!(leaf_to_corresponding_tree(3, 0b0100), Some(2));
assert_eq!(leaf_to_corresponding_tree(3, 0b1011), Some(3));
assert_eq!(leaf_to_corresponding_tree(4, 0b0101), Some(0));
assert_eq!(leaf_to_corresponding_tree(4, 0b0110), Some(1));
assert_eq!(leaf_to_corresponding_tree(4, 0b0111), Some(1));
assert_eq!(leaf_to_corresponding_tree(4, 0b1000), Some(3));
assert_eq!(leaf_to_corresponding_tree(12, 0b01101), Some(0));
assert_eq!(leaf_to_corresponding_tree(12, 0b01110), Some(1));
assert_eq!(leaf_to_corresponding_tree(12, 0b01111), Some(1));
assert_eq!(leaf_to_corresponding_tree(12, 0b10000), Some(4));
}
#[test]
fn test_high_bitmask() {
assert_eq!(high_bitmask(0), usize::MAX);
assert_eq!(high_bitmask(1), usize::MAX << 1);
assert_eq!(high_bitmask(usize::BITS - 2), 0b11usize.rotate_right(2));
assert_eq!(high_bitmask(usize::BITS - 1), 0b1usize.rotate_right(1));
assert_eq!(high_bitmask(usize::BITS), 0, "overflow should be handled");
}
#[test]
fn test_nodes_in_forest() {
assert_eq!(nodes_in_forest(0b0000), 0);
assert_eq!(nodes_in_forest(0b0001), 1);
assert_eq!(nodes_in_forest(0b0010), 3);
assert_eq!(nodes_in_forest(0b0011), 4);
assert_eq!(nodes_in_forest(0b0100), 7);
assert_eq!(nodes_in_forest(0b0101), 8);
assert_eq!(nodes_in_forest(0b0110), 10);
assert_eq!(nodes_in_forest(0b0111), 11);
assert_eq!(nodes_in_forest(0b1000), 15);
assert_eq!(nodes_in_forest(0b1001), 16);
assert_eq!(nodes_in_forest(0b1010), 18);
assert_eq!(nodes_in_forest(0b1011), 19);
}
#[test]
fn test_nodes_in_forest_single_bit() {
assert_eq!(nodes_in_forest(2usize.pow(0)), 2usize.pow(1) - 1);
assert_eq!(nodes_in_forest(2usize.pow(1)), 2usize.pow(2) - 1);
assert_eq!(nodes_in_forest(2usize.pow(2)), 2usize.pow(3) - 1);
assert_eq!(nodes_in_forest(2usize.pow(3)), 2usize.pow(4) - 1);
for bit in 0..(usize::BITS - 1) {
let size = 2usize.pow(bit + 1) - 1;
assert_eq!(nodes_in_forest(1usize << bit), size);
}
}
const LEAVES: [RpoDigest; 7] = [
int_to_node(0),
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
];
#[test]
fn test_mmr_simple() {
let mut postorder = vec![
LEAVES[0],
LEAVES[1],
merge(LEAVES[0], LEAVES[1]),
LEAVES[2],
LEAVES[3],
merge(LEAVES[2], LEAVES[3]),
];
postorder.push(merge(postorder[2], postorder[5]));
postorder.push(LEAVES[4]);
postorder.push(LEAVES[5]);
postorder.push(merge(LEAVES[4], LEAVES[5]));
postorder.push(LEAVES[6]);
let mut mmr = Mmr::new();
assert_eq!(mmr.forest(), 0);
assert_eq!(mmr.nodes.len(), 0);
mmr.add(LEAVES[0]);
assert_eq!(mmr.forest(), 1);
assert_eq!(mmr.nodes.len(), 1);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 1);
assert_eq!(acc.peaks(), &[postorder[0]]);
mmr.add(LEAVES[1]);
assert_eq!(mmr.forest(), 2);
assert_eq!(mmr.nodes.len(), 3);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 2);
assert_eq!(acc.peaks(), &[postorder[2]]);
mmr.add(LEAVES[2]);
assert_eq!(mmr.forest(), 3);
assert_eq!(mmr.nodes.len(), 4);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 3);
assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
mmr.add(LEAVES[3]);
assert_eq!(mmr.forest(), 4);
assert_eq!(mmr.nodes.len(), 7);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 4);
assert_eq!(acc.peaks(), &[postorder[6]]);
mmr.add(LEAVES[4]);
assert_eq!(mmr.forest(), 5);
assert_eq!(mmr.nodes.len(), 8);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 5);
assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
mmr.add(LEAVES[5]);
assert_eq!(mmr.forest(), 6);
assert_eq!(mmr.nodes.len(), 10);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 6);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
mmr.add(LEAVES[6]);
assert_eq!(mmr.forest(), 7);
assert_eq!(mmr.nodes.len(), 11);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 7);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
}
#[test]
fn test_mmr_open() {
let mmr: Mmr = LEAVES.into();
let h01 = merge(LEAVES[0], LEAVES[1]);
let h23 = merge(LEAVES[2], LEAVES[3]);
// node at pos 7 is the root
assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");
// node at pos 6 is the root
let empty: MerklePath = MerklePath::new(vec![]);
let opening = mmr
.open(6)
.expect("Element 6 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, empty);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 6);
mmr.peaks().verify(LEAVES[6], opening).unwrap();
// nodes 4,5 are depth 1
let root_to_path = MerklePath::new(vec![LEAVES[4]]);
let opening = mmr
.open(5)
.expect("Element 5 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 5);
mmr.peaks().verify(LEAVES[5], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[5]]);
let opening = mmr
.open(4)
.expect("Element 4 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 4);
mmr.peaks().verify(LEAVES[4], opening).unwrap();
// nodes 0,1,2,3 are detph 2
let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
let opening = mmr
.open(3)
.expect("Element 3 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 3);
mmr.peaks().verify(LEAVES[3], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
let opening = mmr
.open(2)
.expect("Element 2 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 2);
mmr.peaks().verify(LEAVES[2], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
let opening = mmr
.open(1)
.expect("Element 1 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 1);
mmr.peaks().verify(LEAVES[1], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
let opening = mmr
.open(0)
.expect("Element 0 is contained in the tree, expected an opening result.");
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 0);
mmr.peaks().verify(LEAVES[0], opening).unwrap();
}
#[test]
fn test_mmr_open_older_version() {
let mmr: Mmr = LEAVES.into();
fn is_even(v: &usize) -> bool {
v & 1 == 0
}
// merkle path of a node is empty if there are no elements to pair with it
for pos in (0..mmr.forest()).filter(is_even) {
let forest = pos + 1;
let proof = mmr.open_at(pos, forest).unwrap();
assert_eq!(proof.forest, forest);
assert_eq!(proof.merkle_path.nodes(), []);
assert_eq!(proof.position, pos);
}
// openings match that of a merkle tree
let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
for forest in 4..=LEAVES.len() {
for pos in 0..4 {
let idx = NodeIndex::new(2, pos).unwrap();
let path = mtree.get_path(idx).unwrap();
let proof = mmr.open_at(pos as usize, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
for forest in 6..=LEAVES.len() {
for pos in 0..2 {
let idx = NodeIndex::new(1, pos).unwrap();
let path = mtree.get_path(idx).unwrap();
// account for the bigger tree with 4 elements
let mmr_pos = (pos + 4) as usize;
let proof = mmr.open_at(mmr_pos, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
}
/// Tests the openings of a simple Mmr with a single tree of depth 8.
#[test]
fn test_mmr_open_eight() {
let leaves = [
int_to_node(0),
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
];
let mtree: MerkleTree = leaves.as_slice().try_into().unwrap();
let forest = leaves.len();
let mmr: Mmr = leaves.into();
let root = mtree.root();
let position = 0;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 1;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 2;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 3;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 4;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 5;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 6;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 7;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
}
/// Tests the openings of Mmr with a 3 trees of depths 4, 2, and 1.
#[test]
fn test_mmr_open_seven() {
let mtree1: MerkleTree = LEAVES[..4].try_into().unwrap();
let mtree2: MerkleTree = LEAVES[4..6].try_into().unwrap();
let forest = LEAVES.len();
let mmr: Mmr = LEAVES.into();
let position = 0;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());
let position = 1;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());
let position = 2;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());
let position = 3;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());
let position = 4;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());
let position = 5;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());
let position = 6;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = [].as_ref().into();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
}
#[test]
fn test_mmr_get() {
let mmr: Mmr = LEAVES.into();
assert_eq!(mmr.get(0).unwrap(), LEAVES[0], "value at pos 0 must correspond");
assert_eq!(mmr.get(1).unwrap(), LEAVES[1], "value at pos 1 must correspond");
assert_eq!(mmr.get(2).unwrap(), LEAVES[2], "value at pos 2 must correspond");
assert_eq!(mmr.get(3).unwrap(), LEAVES[3], "value at pos 3 must correspond");
assert_eq!(mmr.get(4).unwrap(), LEAVES[4], "value at pos 4 must correspond");
assert_eq!(mmr.get(5).unwrap(), LEAVES[5], "value at pos 5 must correspond");
assert_eq!(mmr.get(6).unwrap(), LEAVES[6], "value at pos 6 must correspond");
assert!(mmr.get(7).is_err());
}
#[test]
fn test_mmr_invariants() {
let mut mmr = Mmr::new();
for v in 1..=1028 {
mmr.add(int_to_node(v));
let accumulator = mmr.peaks();
assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
assert_eq!(
v as usize,
accumulator.num_leaves(),
"MMR and its accumulator must match leaves count"
);
assert_eq!(
accumulator.num_leaves().count_ones() as usize,
accumulator.peaks().len(),
"bits on leaves must match the number of peaks"
);
let expected_nodes: usize = TrueBitPositionIterator::new(mmr.forest())
.map(|bit_pos| nodes_in_forest(1 << bit_pos))
.sum();
assert_eq!(
expected_nodes,
mmr.nodes.len(),
"the sum of every tree size must be equal to the number of nodes in the MMR (forest: {:b})",
mmr.forest(),
);
}
}
#[test]
fn test_bit_position_iterator() {
assert_eq!(TrueBitPositionIterator::new(0).count(), 0);
assert_eq!(TrueBitPositionIterator::new(0).rev().count(), 0);
assert_eq!(TrueBitPositionIterator::new(1).collect::<Vec<u32>>(), vec![0]);
assert_eq!(TrueBitPositionIterator::new(1).rev().collect::<Vec<u32>>(), vec![0],);
assert_eq!(TrueBitPositionIterator::new(2).collect::<Vec<u32>>(), vec![1]);
assert_eq!(TrueBitPositionIterator::new(2).rev().collect::<Vec<u32>>(), vec![1],);
assert_eq!(TrueBitPositionIterator::new(3).collect::<Vec<u32>>(), vec![0, 1],);
assert_eq!(TrueBitPositionIterator::new(3).rev().collect::<Vec<u32>>(), vec![1, 0],);
assert_eq!(
TrueBitPositionIterator::new(0b11010101).collect::<Vec<u32>>(),
vec![0, 2, 4, 6, 7],
);
assert_eq!(
TrueBitPositionIterator::new(0b11010101).rev().collect::<Vec<u32>>(),
vec![7, 6, 4, 2, 0],
);
}
#[test]
fn test_mmr_inner_nodes() {
let mmr: Mmr = LEAVES.into();
let nodes: Vec<InnerNodeInfo> = mmr.inner_nodes().collect();
let h01 = Rpo256::merge(&[LEAVES[0], LEAVES[1]]);
let h23 = Rpo256::merge(&[LEAVES[2], LEAVES[3]]);
let h0123 = Rpo256::merge(&[h01, h23]);
let h45 = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
let postorder = vec![
InnerNodeInfo {
value: h01,
left: LEAVES[0],
right: LEAVES[1],
},
InnerNodeInfo {
value: h23,
left: LEAVES[2],
right: LEAVES[3],
},
InnerNodeInfo { value: h0123, left: h01, right: h23 },
InnerNodeInfo {
value: h45,
left: LEAVES[4],
right: LEAVES[5],
},
];
assert_eq!(postorder, nodes);
}
#[test]
fn test_mmr_peaks() {
let mmr: Mmr = LEAVES.into();
let forest = 0b0001;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[0]]);
let forest = 0b0010;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2]]);
let forest = 0b0011;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);
let forest = 0b0100;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6]]);
let forest = 0b0101;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);
let forest = 0b0110;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);
let forest = 0b0111;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
}
#[test]
fn test_mmr_hash_peaks() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks();
let first_peak = Rpo256::merge(&[
Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
Rpo256::merge(&[LEAVES[2], LEAVES[3]]),
]);
let second_peak = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
let third_peak = LEAVES[6];
// minimum length is 16
let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
expected_peaks.resize(16, RpoDigest::default());
assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
}
#[test]
fn test_mmr_peaks_hash_less_than_16() {
let mut peaks = Vec::new();
for i in 0..16 {
peaks.push(int_to_node(i));
let num_leaves = (1 << peaks.len()) - 1;
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
// minimum length is 16
let mut expected_peaks = peaks.clone();
expected_peaks.resize(16, RpoDigest::default());
assert_eq!(
accumulator.hash_peaks(),
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
);
}
}
#[test]
fn test_mmr_peaks_hash_odd() {
let peaks: Vec<_> = (0..=17).map(int_to_node).collect();
let num_leaves = (1 << peaks.len()) - 1;
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
// odd length bigger than 16 is padded to the next even number
let mut expected_peaks = peaks;
expected_peaks.resize(18, RpoDigest::default());
assert_eq!(
accumulator.hash_peaks(),
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
);
}
#[test]
fn test_mmr_delta() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks();
// original_forest can't have more elements
assert!(
mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
"Can not provide updates for a newer Mmr"
);
// if the number of elements is the same there is no change
assert!(
mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
"There are no updates for the same Mmr version"
);
// missing the last element added, which is itself a tree peak
assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");
// missing the sibling to complete the tree of depth 2, and the last element
assert_eq!(
mmr.get_delta(5, mmr.forest()).unwrap().data,
vec![LEAVES[5], acc.peaks()[2]],
"one sibling, one peak"
);
// missing the whole last two trees, only send the peaks
assert_eq!(
mmr.get_delta(4, mmr.forest()).unwrap().data,
vec![acc.peaks()[1], acc.peaks()[2]],
"two peaks"
);
// missing the sibling to complete the first tree, and the two last trees
assert_eq!(
mmr.get_delta(3, mmr.forest()).unwrap().data,
vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
// missing half of the first tree, only send the computed element (not the leaves), and the new
// peaks
assert_eq!(
mmr.get_delta(2, mmr.forest()).unwrap().data,
vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
assert_eq!(
mmr.get_delta(1, mmr.forest()).unwrap().data,
vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
}
#[test]
fn test_mmr_delta_old_forest() {
let mmr: Mmr = LEAVES.into();
// from_forest must be smaller-or-equal to to_forest
for version in 1..=mmr.forest() {
assert!(mmr.get_delta(version + 1, version).is_err());
}
// when from_forest and to_forest are equal, there are no updates
for version in 1..=mmr.forest() {
let delta = mmr.get_delta(version, version).unwrap();
assert!(delta.data.is_empty());
assert_eq!(delta.forest, version);
}
// test update which merges the odd peak to the right
for count in 0..(mmr.forest() / 2) {
// *2 because every iteration tests a pair
// +1 because the Mmr is 1-indexed
let from_forest = (count * 2) + 1;
let to_forest = (count * 2) + 2;
let delta = mmr.get_delta(from_forest, to_forest).unwrap();
// *2 because every iteration tests a pair
// +1 because sibling is the odd element
let sibling = (count * 2) + 1;
assert_eq!(delta.data, [LEAVES[sibling]]);
assert_eq!(delta.forest, to_forest);
}
let version = 4;
let delta = mmr.get_delta(1, version).unwrap();
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
assert_eq!(delta.forest, version);
let version = 5;
let delta = mmr.get_delta(1, version).unwrap();
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
assert_eq!(delta.forest, version);
}
#[test]
fn test_partial_mmr_simple() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks();
let mut partial: PartialMmr = peaks.clone().into();
// check initial state of the partial mmr
assert_eq!(partial.peaks(), peaks);
assert_eq!(partial.forest(), peaks.num_leaves());
assert_eq!(partial.forest(), LEAVES.len());
assert_eq!(partial.peaks().num_peaks(), 3);
assert_eq!(partial.nodes.len(), 0);
// check state after adding tracking one element
let proof1 = mmr.open(0).unwrap();
let el1 = mmr.get(proof1.position).unwrap();
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
// check the number of nodes increased by the number of nodes in the proof
assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
// check the values match
let idx = InOrderIndex::from_leaf_pos(proof1.position);
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[0]);
let idx = idx.parent();
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);
let proof2 = mmr.open(1).unwrap();
let el2 = mmr.get(proof2.position).unwrap();
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
// check the number of nodes increased by a single element (the one that is not shared)
assert_eq!(partial.nodes.len(), 3);
// check the values match
let idx = InOrderIndex::from_leaf_pos(proof2.position);
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[0]);
let idx = idx.parent();
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[1]);
}
#[test]
fn test_partial_mmr_update_single() {
let mut full = Mmr::new();
let zero = int_to_node(0);
full.add(zero);
let mut partial: PartialMmr = full.peaks().into();
let proof = full.open(0).unwrap();
partial.track(proof.position, zero, &proof.merkle_path).unwrap();
for i in 1..100 {
let node = int_to_node(i);
full.add(node);
let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
partial.apply(delta).unwrap();
assert_eq!(partial.forest(), full.forest());
assert_eq!(partial.peaks(), full.peaks());
let proof1 = full.open(i as usize).unwrap();
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
let proof2 = partial.open(proof1.position).unwrap().unwrap();
assert_eq!(proof1.merkle_path, proof2.merkle_path);
}
}
#[test]
fn test_mmr_add_invalid_odd_leaf() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks();
let mut partial: PartialMmr = acc.clone().into();
let empty = MerklePath::new(Vec::new());
// None of the other leaves should work
for node in LEAVES.iter().cloned().rev().skip(1) {
let result = partial.track(LEAVES.len() - 1, node, &empty);
assert!(result.is_err());
}
let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
assert!(result.is_ok());
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// Here we manipulate the proof to return a peak index of 1 while the MMR only has 1 peak (with
/// index 0).
#[test]
#[should_panic]
fn test_mmr_proof_num_peaks_exceeds_current_num_peaks() {
let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
let mut proof = mmr.open(3).unwrap();
proof.forest = 5;
proof.position = 4;
mmr.peaks().verify(LEAVES[3], proof).unwrap();
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// We create an MmrProof for a leaf whose peak index to verify against is 1.
/// Then we add another leaf which results in an Mmr with just one peak due to trees
/// being merged. If we try to use the old proof against the new Mmr, we should get an error.
#[test]
#[should_panic]
fn test_mmr_old_proof_num_peaks_exceeds_current_num_peaks() {
let leaves_len = 3;
let mut mmr = Mmr::from(LEAVES[0..leaves_len].iter().cloned());
let leaf_idx = leaves_len - 1;
let proof = mmr.open(leaf_idx).unwrap();
assert!(mmr.peaks().verify(LEAVES[leaf_idx], proof.clone()).is_ok());
mmr.add(LEAVES[leaves_len]);
mmr.peaks().verify(LEAVES[leaf_idx], proof).unwrap();
}
mod property_tests {
use proptest::prelude::*;
use super::leaf_to_corresponding_tree;
proptest! {
#[test]
fn test_last_position_is_always_contained_in_the_last_tree(leaves in any::<usize>().prop_filter("cant have an empty tree", |v| *v != 0)) {
let last_pos = leaves - 1;
let lowest_bit = leaves.trailing_zeros();
assert_eq!(
leaf_to_corresponding_tree(last_pos, leaves),
Some(lowest_bit),
);
}
}
proptest! {
#[test]
fn test_contained_tree_is_always_power_of_two((leaves, pos) in any::<usize>().prop_flat_map(|v| (Just(v), 0..v))) {
let tree_bit = leaf_to_corresponding_tree(pos, leaves).expect("pos is smaller than leaves, there should always be a corresponding tree");
let mask = 1usize << tree_bit;
assert!(tree_bit < usize::BITS, "the result must be a bit in usize");
assert!(mask & leaves != 0, "the result should be a tree in leaves");
}
}
}
// HELPER FUNCTIONS
// ================================================================================================
fn digests_to_elements(digests: &[RpoDigest]) -> Vec<Felt> {
digests.iter().flat_map(Word::from).collect()
}
// short hand for the rpo hash, used to make test code more concise and easy to read
fn merge(l: RpoDigest, r: RpoDigest) -> RpoDigest {
Rpo256::merge(&[l, r])
}

View file

@ -1,62 +0,0 @@
//! Data structures related to Merkle trees based on RPO256 hash function.
use super::{
hash::rpo::{Rpo256, RpoDigest},
Felt, Word, EMPTY_WORD, ZERO,
};
// REEXPORTS
// ================================================================================================
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;
mod index;
pub use index::NodeIndex;
mod merkle_tree;
pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
mod path;
pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
#[cfg(feature = "internal")]
pub use smt::{build_subtree_for_bench, SubtreeLeaf};
pub use smt::{
InnerNode, LeafIndex, MutationSet, NodeMutation, PartialSmt, SimpleSmt, Smt, SmtLeaf,
SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod mmr;
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
mod store;
pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
mod node;
pub use node::InnerNodeInfo;
mod partial_mt;
pub use partial_mt::PartialMerkleTree;
mod error;
pub use error::MerkleError;
// HELPER FUNCTIONS
// ================================================================================================
#[cfg(test)]
const fn int_to_node(value: u64) -> RpoDigest {
RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO])
}
#[cfg(test)]
const fn int_to_leaf(value: u64) -> Word {
[Felt::new(value), ZERO, ZERO, ZERO]
}
#[cfg(test)]
fn digests_to_words(digests: &[RpoDigest]) -> alloc::vec::Vec<Word> {
digests.iter().map(|d| d.into()).collect()
}

View file

@ -1,11 +0,0 @@
use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(test, derive(PartialOrd, Ord))]
pub struct InnerNodeInfo {
pub value: RpoDigest,
pub left: RpoDigest,
pub right: RpoDigest,
}

View file

@ -1,478 +0,0 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::String,
vec::Vec,
};
use core::fmt;
use super::{
InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Word,
EMPTY_WORD,
};
use crate::utils::{
word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================
/// Index of the root node.
const ROOT_INDEX: NodeIndex = NodeIndex::root();
/// An RpoDigest consisting of 4 ZERO elements.
const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD);
// PARTIAL MERKLE TREE
// ================================================================================================
/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. Partial Merkle
/// Tree allows to create Merkle Tree by providing Merkle paths of different lengths.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialMerkleTree {
max_depth: u8,
nodes: BTreeMap<NodeIndex, RpoDigest>,
leaves: BTreeSet<NodeIndex>,
}
impl Default for PartialMerkleTree {
fn default() -> Self {
Self::new()
}
}
impl PartialMerkleTree {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// Minimum supported depth.
pub const MIN_DEPTH: u8 = 1;
/// Maximum supported depth.
pub const MAX_DEPTH: u8 = 64;
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new empty [PartialMerkleTree].
pub fn new() -> Self {
PartialMerkleTree {
max_depth: 0,
nodes: BTreeMap::new(),
leaves: BTreeSet::new(),
}
}
/// Appends the provided paths iterator into the set.
///
/// Analogous to [Self::add_path].
pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
where
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
{
// create an empty tree
let tree = PartialMerkleTree::new();
paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
tree.add_path(index, value, path)?;
Ok(tree)
})
}
/// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided
/// entries.
///
/// # Errors
/// Returns an error if:
/// - If the depth is 0 or is greater than 64.
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
/// - The provided entries contain an insufficient set of nodes.
pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
where
R: IntoIterator<IntoIter = I>,
I: Iterator<Item = (NodeIndex, RpoDigest)> + ExactSizeIterator,
{
let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
let mut leaves = BTreeSet::new();
let mut nodes = BTreeMap::new();
// add data to the leaves and nodes maps and also fill layers map, where the key is the
// depth of the node and value is its index.
for (node_index, hash) in entries.into_iter() {
leaves.insert(node_index);
nodes.insert(node_index, hash);
layers
.entry(node_index.depth())
.and_modify(|layer_vec| layer_vec.push(node_index.value()))
.or_insert(vec![node_index.value()]);
}
// check if the number of leaves can be accommodated by the tree's depth; we use a min
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
let max = 2usize.pow(63);
if layers.len() > max {
return Err(MerkleError::TooManyEntries(max));
}
// Get maximum depth
let max_depth = *layers.keys().next_back().unwrap_or(&0);
// fill layers without nodes with empty vector
for depth in 0..max_depth {
layers.entry(depth).or_default();
}
let mut layer_iter = layers.into_values().rev();
let mut parent_layer = layer_iter.next().unwrap();
let mut current_layer;
for depth in (1..max_depth + 1).rev() {
// set current_layer = parent_layer and parent_layer = layer_iter.next()
current_layer = layer_iter.next().unwrap();
core::mem::swap(&mut current_layer, &mut parent_layer);
for index_value in current_layer {
// get the parent node index
let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
// Check if the parent hash was already calculated. In about half of the cases, we
// don't need to do anything.
if !parent_layer.contains(&parent_node.value()) {
// create current node index
let index = NodeIndex::new(depth, index_value)?;
// get hash of the current node
let node =
nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
// get hash of the sibling node
let sibling = nodes
.get(&index.sibling())
.ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
// get parent hash
let parent = Rpo256::merge(&index.build_node(*node, *sibling));
// add index value of the calculated node to the parents layer
parent_layer.push(parent_node.value());
// add index and hash to the nodes map
nodes.insert(parent_node, parent);
}
}
}
Ok(PartialMerkleTree { max_depth, nodes, leaves })
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the root of this Merkle tree.
pub fn root(&self) -> RpoDigest {
self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
}
/// Returns the depth of this Merkle tree.
pub fn max_depth(&self) -> u8 {
self.max_depth
}
/// Returns a node at the specified NodeIndex.
///
/// # Errors
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
self.nodes
.get(&index)
.ok_or(MerkleError::NodeIndexNotFoundInTree(index))
.copied()
}
/// Returns true if provided index contains in the leaves set, false otherwise.
pub fn is_leaf(&self, index: NodeIndex) -> bool {
self.leaves.contains(&index)
}
/// Returns a vector of paths from every leaf to the root.
pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> {
let mut paths = Vec::new();
self.leaves.iter().for_each(|&leaf| {
paths.push((
leaf,
ValuePath {
value: self.get_node(leaf).expect("Failed to get leaf node"),
path: self.get_path(leaf).expect("Failed to get path"),
},
));
});
paths
}
/// Returns a Merkle path from the node at the specified index to the root.
///
/// The node itself is not included in the path.
///
/// # Errors
/// Returns an error if:
/// - the specified index has depth set to 0 or the depth is greater than the depth of this
/// Merkle tree.
/// - the specified index is not contained in the nodes map.
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.max_depth() {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}
if !self.nodes.contains_key(&index) {
return Err(MerkleError::NodeIndexNotFoundInTree(index));
}
let mut path = Vec::new();
for _ in 0..index.depth() {
let sibling_index = index.sibling();
index.move_up();
let sibling =
self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
path.push(sibling);
}
Ok(MerklePath::new(path))
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the leaves of this [PartialMerkleTree].
pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
self.leaves.iter().map(|&leaf| {
(
leaf,
self.get_node(leaf)
.unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
)
})
}
/// Returns an iterator over the inner nodes of this Merkle tree.
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
inner_nodes.map(|(index, digest)| {
let left_hash =
self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
let right_hash =
self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
InnerNodeInfo {
value: *digest,
left: *left_hash,
right: *right_hash,
}
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Adds the nodes of the specified Merkle path to this [PartialMerkleTree]. The `index_value`
/// and `value` parameters specify the leaf node at which the path starts.
///
/// # Errors
/// Returns an error if:
/// - The depth of the specified node_index is greater than 64 or smaller than 1.
/// - The specified path is not consistent with other paths in the set (i.e., resolves to a
/// different root).
pub fn add_path(
&mut self,
index_value: u64,
value: RpoDigest,
path: MerklePath,
) -> Result<(), MerkleError> {
let index_value = NodeIndex::new(path.len() as u8, index_value)?;
Self::check_depth(index_value.depth())?;
self.update_depth(index_value.depth());
// add provided node and its sibling to the leaves set
self.leaves.insert(index_value);
let sibling_node_index = index_value.sibling();
self.leaves.insert(sibling_node_index);
// add provided node and its sibling to the nodes map
self.nodes.insert(index_value, value);
self.nodes.insert(sibling_node_index, path[0]);
// traverse to the root, updating the nodes
let mut index_value = index_value;
let node = Rpo256::merge(&index_value.build_node(value, path[0]));
let root = path.iter().skip(1).copied().fold(node, |node, hash| {
index_value.move_up();
// insert calculated node to the nodes map
self.nodes.insert(index_value, node);
// if the calculated node was a leaf, remove it from leaves set.
self.leaves.remove(&index_value);
let sibling_node = index_value.sibling();
// Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only
// if it is a new node (it wasn't in nodes map).
// Node can be in 3 states: internal node, leaf of the tree and not a tree node at all.
// - Internal node can only stay in this state -- addition of a new path can't make it
// a leaf or remove it from the tree.
// - Leaf node can stay in the same state (remain a leaf) or can become an internal
// node. In the first case we don't need to do anything, and the second case is handled
// by the call of `self.leaves.remove(&index_value);`
// - New node can be a calculated node or a "sibling" node from a Merkle Path:
// --- Calculated node, obviously, never can be a leaf.
// --- Sibling node can be only a leaf, because otherwise it is not a new node.
if self.nodes.insert(sibling_node, hash).is_none() {
self.leaves.insert(sibling_node);
}
Rpo256::merge(&index_value.build_node(node, hash))
});
// if the path set is empty (the root is all ZEROs), set the root to the root of the added
// path; otherwise, the root of the added path must be identical to the current root
if self.root() == EMPTY_DIGEST {
self.nodes.insert(ROOT_INDEX, root);
} else if self.root() != root {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: root,
});
}
Ok(())
}
/// Updates value of the leaf at the specified index returning the old leaf value.
///
/// By default the specified index is assumed to belong to the deepest layer. If the considered
/// node does not belong to the tree, the first node on the way to the root will be changed.
///
/// This also recomputes all hashes between the leaf and the root, updating the root itself.
///
/// # Errors
/// Returns an error if:
/// - No entry exists at the specified index.
/// - The specified index is greater than the maximum number of nodes on the deepest layer.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
let mut node_index = NodeIndex::new(self.max_depth(), index)?;
// proceed to the leaf
for _ in 0..node_index.depth() {
if !self.leaves.contains(&node_index) {
node_index.move_up();
}
}
// add node value to the nodes Map
let old_value = self
.nodes
.insert(node_index, value.into())
.ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;
// if the old value and new value are the same, there is nothing to update
if value == *old_value {
return Ok(old_value);
}
let mut value = value.into();
for _ in 0..node_index.depth() {
let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
value = Rpo256::merge(&node_index.build_node(value, *sibling));
node_index.move_up();
self.nodes.insert(node_index, value);
}
Ok(old_value)
}
// UTILITY FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Utility to visualize a [PartialMerkleTree] in text.
pub fn print(&self) -> Result<String, fmt::Error> {
let indent = " ";
let mut s = String::new();
s.push_str("root: ");
s.push_str(&word_to_hex(&self.root())?);
s.push('\n');
for d in 1..=self.max_depth() {
let entries = 2u64.pow(d.into());
for i in 0..entries {
let index = NodeIndex::new(d, i).expect("The index must always be valid");
let node = self.get_node(index);
let node = match node {
Err(_) => continue,
Ok(node) => node,
};
for _ in 0..d {
s.push_str(indent);
}
s.push_str(&format!("({}, {}): ", index.depth(), index.value()));
s.push_str(&word_to_hex(&node)?);
s.push('\n');
}
}
Ok(s)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Updates depth value with the maximum of current and provided depth.
fn update_depth(&mut self, new_depth: u8) {
self.max_depth = new_depth.max(self.max_depth);
}
/// Returns an error if the depth is 0 or is greater than 64.
fn check_depth(depth: u8) -> Result<(), MerkleError> {
// validate the range of the depth.
if depth < Self::MIN_DEPTH {
return Err(MerkleError::DepthTooSmall(depth));
} else if Self::MAX_DEPTH < depth {
return Err(MerkleError::DepthTooBig(depth as u64));
}
Ok(())
}
}
// SERIALIZATION
// ================================================================================================
impl Serializable for PartialMerkleTree {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// write leaf nodes
target.write_u64(self.leaves.len() as u64);
for leaf_index in self.leaves.iter() {
leaf_index.write_into(target);
self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
}
}
}
impl Deserializable for PartialMerkleTree {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let leaves_len = source.read_u64()? as usize;
let mut leaf_nodes = Vec::with_capacity(leaves_len);
// add leaf nodes to the vector
for _ in 0..leaves_len {
let index = NodeIndex::read_from(source)?;
let hash = RpoDigest::read_from(source)?;
leaf_nodes.push((index, hash));
}
let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
})?;
Ok(pmt)
}
}

View file

@ -1,466 +0,0 @@
use alloc::{collections::BTreeMap, vec::Vec};
use super::{
super::{
digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
PartialMerkleTree,
},
Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};
// TEST DATA
// ================================================================================================
const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0);
const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1);
const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0);
const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1);
const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2);
const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3);
const NODE30: NodeIndex = NodeIndex::new_unchecked(3, 0);
const NODE31: NodeIndex = NodeIndex::new_unchecked(3, 1);
const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2);
const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3);
const VALUES8: [RpoDigest; 8] = [
int_to_node(30),
int_to_node(31),
int_to_node(32),
int_to_node(33),
int_to_node(34),
int_to_node(35),
int_to_node(36),
int_to_node(37),
];
// TESTS
// ================================================================================================
// For the Partial Merkle Tree tests we will use parts of the Merkle Tree which full form is
// illustrated below:
//
// __________ root __________
// / \
// ____ 10 ____ ____ 11 ____
// / \ / \
// 20 21 22 23
// / \ / \ / \ / \
// (30) (31) (32) (33) (34) (35) (36) (37)
//
// Where node number is a concatenation of its depth and index. For example, node with
// NodeIndex(3, 5) will be labeled as `35`. Leaves of the tree are shown as nodes with parenthesis
// (33).
/// Checks that creation of the PMT with `with_leaves()` constructor is working correctly.
#[test]
fn with_leaves() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let leaf_nodes_vec = vec![
(NODE20, mt.get_node(NODE20).unwrap()),
(NODE32, mt.get_node(NODE32).unwrap()),
(NODE33, mt.get_node(NODE33).unwrap()),
(NODE22, mt.get_node(NODE22).unwrap()),
(NODE23, mt.get_node(NODE23).unwrap()),
];
let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();
let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap();
assert_eq!(expected_root, pmt.root())
}
/// Checks that `with_leaves()` function returns an error when using incomplete set of nodes.
#[test]
fn err_with_leaves() {
// NODE22 is missing
let leaf_nodes_vec = vec![
(NODE20, int_to_node(20)),
(NODE32, int_to_node(32)),
(NODE33, int_to_node(33)),
(NODE23, int_to_node(23)),
];
let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();
assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err());
}
/// Checks that root returned by `root()` function is equal to the expected one.
#[test]
fn get_root() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert_eq!(expected_root, pmt.root());
}
/// This test checks correctness of the `add_path()` and `get_path()` functions. First it creates a
/// PMT using `add_path()` by adding Merkle Paths from node 33 and node 22 to the empty PMT. Then
/// it checks that paths returned by `get_path()` function are equal to the expected ones.
#[test]
fn add_and_get_paths() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let expected_path33 = ms.get_path(expected_root, NODE33).unwrap();
let expected_path22 = ms.get_path(expected_root, NODE22).unwrap();
let mut pmt = PartialMerkleTree::new();
pmt.add_path(3, expected_path33.value, expected_path33.path.clone()).unwrap();
pmt.add_path(2, expected_path22.value, expected_path22.path.clone()).unwrap();
let path33 = pmt.get_path(NODE33).unwrap();
let path22 = pmt.get_path(NODE22).unwrap();
let actual_root = pmt.root();
assert_eq!(expected_path33.path, path33);
assert_eq!(expected_path22.path, path22);
assert_eq!(expected_root, actual_root);
}
/// Checks that function `get_node` used on nodes 10 and 32 returns expected values.
#[test]
fn get_node() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert_eq!(ms.get_node(expected_root, NODE32).unwrap(), pmt.get_node(NODE32).unwrap());
assert_eq!(ms.get_node(expected_root, NODE10).unwrap(), pmt.get_node(NODE10).unwrap());
}
/// Updates leaves of the PMT using `update_leaf()` function and checks that new root of the tree
/// is equal to the expected one.
#[test]
fn update_leaf() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let root = mt.root();
let mut ms = MerkleStore::from(&mt);
let path33 = ms.get_path(root, NODE33).unwrap();
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
let new_value32 = int_to_node(132);
let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root;
pmt.update_leaf(2, *new_value32).unwrap();
let actual_root = pmt.root();
assert_eq!(expected_root, actual_root);
let new_value20 = int_to_node(120);
let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root;
pmt.update_leaf(0, *new_value20).unwrap();
let actual_root = pmt.root();
assert_eq!(expected_root, actual_root);
let new_value11 = int_to_node(111);
let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root;
pmt.update_leaf(6, *new_value11).unwrap();
let actual_root = pmt.root();
assert_eq!(expected_root, actual_root);
}
/// Checks that paths of the PMT returned by `paths()` function are equal to the expected ones.
#[test]
fn get_paths() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let mut pmt = PartialMerkleTree::new();
pmt.add_path(3, path33.value, path33.path).unwrap();
pmt.add_path(2, path22.value, path22.path).unwrap();
// After PMT creation with path33 (33; 32, 20, 11) and path22 (22; 23, 10) we will have this
// tree:
//
// ______root______
// / \
// ___10___ ___11___
// / \ / \
// (20) 21 (22) (23)
// / \
// (32) (33)
//
// Which have leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
// for each leaf.
let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
.iter()
.map(|&leaf| {
(
leaf,
ValuePath {
value: mt.get_node(leaf).unwrap(),
path: mt.get_path(leaf).unwrap(),
},
)
})
.collect();
let actual_paths = pmt.to_paths();
assert_eq!(expected_paths, actual_paths);
}
// Checks correctness of leaves determination when using the `leaves()` function.
#[test]
fn leaves() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
// After PMT creation with path33 (33; 32, 20, 11) we will have this tree:
//
// ______root______
// / \
// ___10___ (11)
// / \
// (20) 21
// / \
// (32) (33)
//
// Which have leaf nodes 11, 20, 32 and 33.
let value11 = mt.get_node(NODE11).unwrap();
let value20 = mt.get_node(NODE20).unwrap();
let value32 = mt.get_node(NODE32).unwrap();
let value33 = mt.get_node(NODE33).unwrap();
let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
let expected_leaves = leaves.iter().copied();
assert!(expected_leaves.eq(pmt.leaves()));
pmt.add_path(2, path22.value, path22.path).unwrap();
// After adding the path22 (22; 23, 10) to the existing PMT we will have this tree:
//
// ______root______
// / \
// ___10___ ___11___
// / \ / \
// (20) 21 (22) (23)
// / \
// (32) (33)
//
// Which have leaf nodes 20, 22, 23, 32 and 33.
let value20 = mt.get_node(NODE20).unwrap();
let value22 = mt.get_node(NODE22).unwrap();
let value23 = mt.get_node(NODE23).unwrap();
let value32 = mt.get_node(NODE32).unwrap();
let value33 = mt.get_node(NODE33).unwrap();
let leaves = vec![
(NODE20, value20),
(NODE22, value22),
(NODE23, value23),
(NODE32, value32),
(NODE33, value33),
];
let expected_leaves = leaves.iter().copied();
assert!(expected_leaves.eq(pmt.leaves()));
}
/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected
/// ones.
#[test]
fn test_inner_node_iterator() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
// get actual inner nodes
let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();
let expected_n00 = mt.root();
let expected_n10 = mt.get_node(NODE10).unwrap();
let expected_n11 = mt.get_node(NODE11).unwrap();
let expected_n20 = mt.get_node(NODE20).unwrap();
let expected_n21 = mt.get_node(NODE21).unwrap();
let expected_n32 = mt.get_node(NODE32).unwrap();
let expected_n33 = mt.get_node(NODE33).unwrap();
// create vector of the expected inner nodes
let mut expected = vec![
InnerNodeInfo {
value: expected_n00,
left: expected_n10,
right: expected_n11,
},
InnerNodeInfo {
value: expected_n10,
left: expected_n20,
right: expected_n21,
},
InnerNodeInfo {
value: expected_n21,
left: expected_n32,
right: expected_n33,
},
];
assert_eq!(actual, expected);
// add another path to the Partial Merkle Tree
pmt.add_path(2, path22.value, path22.path).unwrap();
// get new actual inner nodes
let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();
let expected_n22 = mt.get_node(NODE22).unwrap();
let expected_n23 = mt.get_node(NODE23).unwrap();
let info_11 = InnerNodeInfo {
value: expected_n11,
left: expected_n22,
right: expected_n23,
};
// add new inner node to the existing vertor
expected.insert(2, info_11);
assert_eq!(actual, expected);
}
/// Checks that serialization and deserialization implementations for the PMT are working
/// correctly.
#[test]
fn serialization() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let path22 = ms.get_path(expected_root, NODE22).unwrap();
let pmt = PartialMerkleTree::with_paths([
(3, path33.value, path33.path),
(2, path22.value, path22.path),
])
.unwrap();
let serialized_pmt = pmt.to_bytes();
let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap();
assert_eq!(deserialized_pmt, pmt);
}
/// Checks that deserialization fails with incorrect data.
#[test]
fn err_deserialization() {
let mut tree_bytes: Vec<u8> = vec![5];
tree_bytes.append(&mut NODE20.to_bytes());
tree_bytes.append(&mut int_to_node(20).to_bytes());
tree_bytes.append(&mut NODE21.to_bytes());
tree_bytes.append(&mut int_to_node(21).to_bytes());
// node with depth 1 could have index 0 or 1, but it has 2
tree_bytes.append(&mut vec![1, 2]);
tree_bytes.append(&mut int_to_node(11).to_bytes());
assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err());
}
/// Checks that addition of the path with different root will cause an error.
#[test]
fn err_add_path() {
let path33 = vec![int_to_node(1), int_to_node(2), int_to_node(3)].into();
let path22 = vec![int_to_node(4), int_to_node(5)].into();
let mut pmt = PartialMerkleTree::new();
pmt.add_path(3, int_to_node(6), path33).unwrap();
assert!(pmt.add_path(2, int_to_node(7), path22).is_err());
}
/// Checks that the request of the node which is not in the PMT will cause an error.
#[test]
fn err_get_node() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert!(pmt.get_node(NODE22).is_err());
assert!(pmt.get_node(NODE23).is_err());
assert!(pmt.get_node(NODE30).is_err());
assert!(pmt.get_node(NODE31).is_err());
}
/// Checks that the request of the path from the leaf which is not in the PMT will cause an error.
#[test]
fn err_get_path() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert!(pmt.get_path(NODE22).is_err());
assert!(pmt.get_path(NODE23).is_err());
assert!(pmt.get_path(NODE30).is_err());
assert!(pmt.get_path(NODE31).is_err());
}
#[test]
fn err_update_leaf() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let expected_root = mt.root();
let ms = MerkleStore::from(&mt);
let path33 = ms.get_path(expected_root, NODE33).unwrap();
let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
assert!(pmt.update_leaf(8, *int_to_node(38)).is_err());
}

View file

@ -1,282 +0,0 @@
use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};
use super::{InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest};
use crate::{
utils::{ByteReader, Deserializable, DeserializationError, Serializable},
Word,
};
// MERKLE PATH
// ================================================================================================
/// A merkle path container, composed of a sequence of nodes of a Merkle tree.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerklePath {
nodes: Vec<RpoDigest>,
}
impl MerklePath {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Creates a new Merkle path from a list of nodes.
pub fn new(nodes: Vec<RpoDigest>) -> Self {
assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items");
Self { nodes }
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
/// Returns the depth in which this Merkle path proof is valid.
pub fn depth(&self) -> u8 {
self.nodes.len() as u8
}
/// Returns a reference to the [MerklePath]'s nodes.
pub fn nodes(&self) -> &[RpoDigest] {
&self.nodes
}
/// Computes the merkle root for this opening.
pub fn compute_root(&self, index: u64, node: RpoDigest) -> Result<RpoDigest, MerkleError> {
let mut index = NodeIndex::new(self.depth(), index)?;
let root = self.nodes.iter().copied().fold(node, |node, sibling| {
// compute the node and move to the next iteration.
let input = index.build_node(node, sibling);
index.move_up();
Rpo256::merge(&input)
});
Ok(root)
}
/// Verifies the Merkle opening proof towards the provided root.
///
/// # Errors
/// Returns an error if:
/// - provided node index is invalid.
/// - root calculated during the verification differs from the provided one.
pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> Result<(), MerkleError> {
let computed_root = self.compute_root(index, node)?;
if &computed_root != root {
return Err(MerkleError::ConflictingRoots {
expected_root: *root,
actual_root: computed_root,
});
}
Ok(())
}
/// Returns an iterator over every inner node of this [MerklePath].
///
/// The iteration order is unspecified.
///
/// # Errors
/// Returns an error if the specified index is not valid for this path.
pub fn inner_nodes(
&self,
index: u64,
node: RpoDigest,
) -> Result<InnerNodeIterator, MerkleError> {
Ok(InnerNodeIterator {
nodes: &self.nodes,
index: NodeIndex::new(self.depth(), index)?,
value: node,
})
}
}
// CONVERSIONS
// ================================================================================================
impl From<MerklePath> for Vec<RpoDigest> {
fn from(path: MerklePath) -> Self {
path.nodes
}
}
impl From<Vec<RpoDigest>> for MerklePath {
fn from(path: Vec<RpoDigest>) -> Self {
Self::new(path)
}
}
impl From<&[RpoDigest]> for MerklePath {
fn from(path: &[RpoDigest]) -> Self {
Self::new(path.to_vec())
}
}
impl Deref for MerklePath {
// we use `Vec` here instead of slice so we can call vector mutation methods directly from the
// merkle path (example: `Vec::remove`).
type Target = Vec<RpoDigest>;
fn deref(&self) -> &Self::Target {
&self.nodes
}
}
impl DerefMut for MerklePath {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.nodes
}
}
// ITERATORS
// ================================================================================================
impl FromIterator<RpoDigest> for MerklePath {
fn from_iter<T: IntoIterator<Item = RpoDigest>>(iter: T) -> Self {
Self::new(iter.into_iter().collect())
}
}
impl IntoIterator for MerklePath {
type Item = RpoDigest;
type IntoIter = alloc::vec::IntoIter<RpoDigest>;
fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
}
}
/// An iterator over internal nodes of a [MerklePath].
pub struct InnerNodeIterator<'a> {
nodes: &'a Vec<RpoDigest>,
index: NodeIndex,
value: RpoDigest,
}
impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
if !self.index.is_root() {
let sibling_pos = self.nodes.len() - self.index.depth() as usize;
let (left, right) = if self.index.is_value_odd() {
(self.nodes[sibling_pos], self.value)
} else {
(self.value, self.nodes[sibling_pos])
};
self.value = Rpo256::merge(&[left, right]);
self.index.move_up();
Some(InnerNodeInfo { value: self.value, left, right })
} else {
None
}
}
}
// MERKLE PATH CONTAINERS
// ================================================================================================
/// A container for a [crate::Word] value and its [MerklePath] opening.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
/// The node value opening for `path`.
pub value: RpoDigest,
/// The path from `value` to `root` (exclusive).
pub path: MerklePath,
}
impl ValuePath {
/// Returns a new [ValuePath] instantiated from the specified value and path.
pub fn new(value: RpoDigest, path: MerklePath) -> Self {
Self { value, path }
}
}
impl From<(MerklePath, Word)> for ValuePath {
fn from((path, value): (MerklePath, Word)) -> Self {
ValuePath::new(value.into(), path)
}
}
/// A container for a [MerklePath] and its [crate::Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
/// root. For more information, check [MerklePath::verify].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RootPath {
/// The node value opening for `path`.
pub root: RpoDigest,
/// The path from `value` to `root` (exclusive).
pub path: MerklePath,
}
// SERIALIZATION
// ================================================================================================
impl Serializable for MerklePath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
target.write_u8(self.nodes.len() as u8);
target.write_many(&self.nodes);
}
}
impl Deserializable for MerklePath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let count = source.read_u8()?.into();
let nodes = source.read_many::<RpoDigest>(count)?;
Ok(Self { nodes })
}
}
impl Serializable for ValuePath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.value.write_into(target);
self.path.write_into(target);
}
}
impl Deserializable for ValuePath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let value = RpoDigest::read_from(source)?;
let path = MerklePath::read_from(source)?;
Ok(Self { value, path })
}
}
impl Serializable for RootPath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.root.write_into(target);
self.path.write_into(target);
}
}
impl Deserializable for RootPath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let root = RpoDigest::read_from(source)?;
let path = MerklePath::read_from(source)?;
Ok(Self { root, path })
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use crate::merkle::{int_to_node, MerklePath};
#[test]
fn test_inner_nodes() {
let nodes = vec![int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
let merkle_path = MerklePath::new(nodes);
let index = 6;
let node = int_to_node(5);
let root = merkle_path.compute_root(index, node).unwrap();
let inner_root = merkle_path.inner_nodes(index, node).unwrap().last().unwrap().value;
assert_eq!(root, inner_root);
}
}

View file

@ -1,605 +0,0 @@
use alloc::{collections::BTreeSet, vec::Vec};
use core::mem;
use num::Integer;
use super::{
EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet,
NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
};
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};
#[cfg(test)]
mod tests;
type MutatedSubtreeLeaves = Vec<Vec<SubtreeLeaf>>;
// CONCURRENT IMPLEMENTATIONS
// ================================================================================================
impl Smt {
/// Parallel implementation of [`Smt::with_entries()`].
///
/// This method constructs a new sparse Merkle tree concurrently by processing subtrees in
/// parallel, working from the bottom up. The process works as follows:
///
/// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf
/// indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
///
/// 2. The subtrees are then processed in parallel:
/// - For each subtree, compute the inner nodes from depth D down to depth D-8.
/// - Each subtree computation yields a new subtree root and its associated inner nodes.
///
/// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration,
/// which processes the next 8 levels up. This continues until the final root of the tree is
/// computed at depth 0.
pub(crate) fn with_entries_concurrent(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
let mut seen_keys = BTreeSet::new();
let entries: Vec<_> = entries
.into_iter()
.map(|(key, value)| {
if seen_keys.insert(key) {
Ok((key, value))
} else {
Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
))
}
})
.collect::<Result<_, _>>()?;
if entries.is_empty() {
return Ok(Self::default());
}
let (inner_nodes, leaves) = Self::build_subtrees(entries);
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
}
/// Parallel implementation of [`Smt::compute_mutations()`].
///
/// This method computes mutations by recursively processing subtrees in parallel, working from
/// the bottom up. The process works as follows:
///
/// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf
/// indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
///
/// 2. The subtrees containing modifications are then processed in parallel:
/// - For each modified subtree, compute node mutations from depth D up to depth D-8
/// - Each subtree computation yields a new root at depth D-8 and its associated mutations
///
/// 3. These subtree roots become the "leaves" for the next iteration, which processes the next
/// 8 levels up. This continues until reaching the tree's root at depth 0.
pub(crate) fn compute_mutations_concurrent(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word>
where
Self: Sized + Sync,
{
use rayon::prelude::*;
// Collect and sort key-value pairs by their corresponding leaf index
let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());
// Convert sorted pairs into mutated leaves and capture any new pairs
let (mut subtree_leaves, new_pairs) =
self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs);
let mut node_mutations = NodeMutations::default();
// Process each depth level in reverse, stepping by the subtree depth
for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// Parallel processing of each subtree to generate mutations and roots
let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
.into_par_iter()
.map(|subtree| {
debug_assert!(subtree.is_sorted() && !subtree.is_empty());
self.build_subtree_mutations(subtree, SMT_DEPTH, depth)
})
.unzip();
// Prepare leaves for the next depth level
subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
// Aggregate all node mutations
node_mutations.extend(mutations_per_subtree.into_iter().flatten());
debug_assert!(!subtree_leaves.is_empty());
}
// Finalize the mutation set with updated roots and mutations
MutationSet {
old_root: self.root(),
new_root: subtree_leaves[0][0].hash,
node_mutations,
new_pairs,
}
}
// SUBTREE MUTATION
// --------------------------------------------------------------------------------------------
/// Computes the node mutations and the root of a subtree
fn build_subtree_mutations(
&self,
mut leaves: Vec<SubtreeLeaf>,
tree_depth: u8,
bottom_depth: u8,
) -> (NodeMutations, SubtreeLeaf)
where
Self: Sized,
{
debug_assert!(bottom_depth <= tree_depth);
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
let subtree_root_depth = bottom_depth - SUBTREE_DEPTH;
let mut node_mutations: NodeMutations = Default::default();
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
for current_depth in (subtree_root_depth..bottom_depth).rev() {
debug_assert!(current_depth <= bottom_depth);
let next_depth = current_depth + 1;
let mut iter = leaves.drain(..).peekable();
while let Some(first_leaf) = iter.next() {
// This constructs a valid index because next_depth will never exceed the depth of
// the tree.
let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent();
let parent_node = self.get_inner_node(parent_index);
let combined_node = fetch_sibling_pair(&mut iter, first_leaf, parent_node);
let combined_hash = combined_node.hash();
let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth);
// Add the parent node even if it is empty for proper upward updates
next_leaves.push(SubtreeLeaf {
col: parent_index.value(),
hash: combined_hash,
});
node_mutations.insert(
parent_index,
if combined_hash != empty_hash {
NodeMutation::Addition(combined_node)
} else {
NodeMutation::Removal
},
);
}
drop(iter);
leaves = mem::take(&mut next_leaves);
}
debug_assert_eq!(leaves.len(), 1);
let root_leaf = leaves.pop().unwrap();
(node_mutations, root_leaf)
}
// SUBTREE CONSTRUCTION
// --------------------------------------------------------------------------------------------
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
///
/// `entries` need not be sorted. This function will sort them.
fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
});
Self::build_subtrees_from_sorted_entries(entries)
}
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
///
/// This function is mostly an implementation detail of
/// [`Smt::with_entries_concurrent()`].
fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
use rayon::prelude::*;
let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: initial_leaves,
} = Self::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
leaf_subtrees
.into_par_iter()
.map(|subtree| {
debug_assert!(subtree.is_sorted());
debug_assert!(!subtree.is_empty());
let (nodes, subtree_root) =
build_subtree(subtree, SMT_DEPTH, current_depth);
(nodes, subtree_root)
})
.unzip();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
debug_assert!(!leaf_subtrees.is_empty());
}
(accumulated_nodes, initial_leaves)
}
// LEAF NODE CONSTRUCTION
// --------------------------------------------------------------------------------------------
/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
/// the inputs to feed into [`build_subtree()`].
///
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
///
/// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> {
Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
}
/// Computes leaves from a set of key-value pairs and current leaf values.
/// Derived from `sorted_pairs_to_leaves`
fn sorted_pairs_to_mutated_subtree_leaves(
&self,
pairs: Vec<(RpoDigest, Word)>,
) -> (MutatedSubtreeLeaves, UnorderedMap<RpoDigest, Word>) {
// Map to track new key-value pairs for mutated leaves
let mut new_pairs = UnorderedMap::new();
let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
let mut leaf = self.get_leaf(&leaf_pairs[0].0);
for (key, value) in leaf_pairs {
// Check if the value has changed
let old_value =
new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
// Skip if the value hasn't changed
if value == old_value {
continue;
}
// Otherwise, update the leaf and track the new key-value pair
leaf = self.construct_prospective_leaf(leaf, &key, &value);
new_pairs.insert(key, value);
}
leaf
});
(accumulator.leaves, new_pairs)
}
/// Processes sorted key-value pairs to compute leaves for a subtree.
///
/// This function groups key-value pairs by their corresponding column index and processes each
/// group to construct leaves. The actual construction of the leaf is delegated to the
/// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating
/// new leaves or mutating existing ones).
///
/// # Parameters
/// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index
/// column (not simply by key). If the input is not sorted correctly, the function will
/// produce incorrect results and may panic in debug mode.
/// - `process_leaf`: A callback function used to process each group of key-value pairs
/// corresponding to the same column index. The callback takes a vector of key-value pairs for
/// a single column and returns the constructed leaf for that column.
///
/// # Returns
/// A `PairComputations<u64, Self::Leaf>` containing:
/// - `nodes`: A mapping of column indices to the constructed leaves.
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
///
/// # Panics
/// This function will panic in debug mode if the input `pairs` are not sorted by column index.
fn process_sorted_pairs_to_leaves<F>(
pairs: Vec<(RpoDigest, Word)>,
mut process_leaf: F,
) -> PairComputations<u64, SmtLeaf>
where
F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf,
{
use rayon::prelude::*;
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();
// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
// out and store them in our accumulated leaves.
let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default();
let mut iter = pairs.into_iter().peekable();
while let Some((key, value)) = iter.next() {
let col = Self::key_to_leaf_index(&key).index.value();
let peeked_col = iter.peek().map(|(key, _v)| {
let index = Self::key_to_leaf_index(key);
let next_col = index.index.value();
// We panic if `pairs` is not sorted by column.
debug_assert!(next_col >= col);
next_col
});
current_leaf_buffer.push((key, value));
// If the next pair is the same column as this one, then we're done after adding this
// pair to the buffer.
if peeked_col == Some(col) {
continue;
}
// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let leaf = process_leaf(leaf_pairs);
accumulator.nodes.insert(col, leaf);
debug_assert!(current_leaf_buffer.is_empty());
}
// Compute the leaves from the nodes concurrently
let mut accumulated_leaves: Vec<SubtreeLeaf> = accumulator
.nodes
.clone()
.into_par_iter()
.map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) })
.collect();
// Sort the leaves by column
accumulated_leaves.par_sort_by_key(|leaf| leaf.col);
// TODO: determine is there is any notable performance difference between computing
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
// subtree boundaries as we go. Either way this function is only used at the beginning of a
// parallel construction, so it should not be a critical path.
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator
}
}
// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///
/// Note that these represent "conceptual" leaves of some subtree, not necessarily
/// the leaf type for the sparse Merkle tree.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf {
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
pub col: u64,
/// The hash of the node this `SubtreeLeaf` represents.
pub hash: RpoDigest,
}
/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: UnorderedMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}
// Derive requires `L` to impl Default, even though we don't actually need that.
impl<K, L> Default for PairComputations<K, L> {
fn default() -> Self {
Self {
nodes: Default::default(),
leaves: Default::default(),
}
}
}
#[derive(Debug)]
pub(crate) struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
// TODO: determine if there is any notable performance difference between taking a Vec,
// which many need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
// The latter may have self-referential properties that are impossible to express in purely
// safe Rust Rust.
Self { leaves: leaves.drain(..).peekable() }
}
}
impl Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();
let mut last_subtree_col = 0;
while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}
if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}
Some(subtree)
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
fn build_subtree(
mut leaves: Vec<SubtreeLeaf>,
tree_depth: u8,
bottom_depth: u8,
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
#[cfg(debug_assertions)]
{
// Ensure that all leaves have unique column indices within this subtree.
// In normal usage via public APIs, this should never happen because leaf
// construction enforces uniqueness. However, when testing or benchmarking
// `build_subtree()` in isolation, duplicate columns can appear if input
// constraints are not enforced.
let mut seen_cols = BTreeSet::new();
for leaf in &leaves {
assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col);
}
}
debug_assert!(bottom_depth <= tree_depth);
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
let subtree_root = bottom_depth - SUBTREE_DEPTH;
let mut inner_nodes: UnorderedMap<NodeIndex, InnerNode> = Default::default();
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
for next_depth in (subtree_root..bottom_depth).rev() {
debug_assert!(next_depth <= bottom_depth);
// `next_depth` is the stuff we're making.
// `current_depth` is the stuff we have.
let current_depth = next_depth + 1;
let mut iter = leaves.drain(..).peekable();
while let Some(first) = iter.next() {
// On non-continuous iterations, including the first iteration, `first_column` may
// be a left or right node. On subsequent continuous iterations, we will always call
// `iter.next()` twice.
// On non-continuous iterations (including the very first iteration), this column
// could be either on the left or the right. If the next iteration is not
// discontinuous with our right node, then the next iteration's
let is_right = first.col.is_odd();
let (left, right) = if is_right {
// Discontinuous iteration: we have no left node, so it must be empty.
let left = SubtreeLeaf {
col: first.col - 1,
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
};
let right = first;
(left, right)
} else {
let left = first;
let right_col = first.col + 1;
let right = match iter.peek().copied() {
Some(SubtreeLeaf { col, .. }) if col == right_col => {
// Our inputs must be sorted.
debug_assert!(left.col <= col);
// The next leaf in the iterator is our sibling. Use it and consume it!
iter.next().unwrap()
},
// Otherwise, the leaves don't contain our sibling, so our sibling must be
// empty.
_ => SubtreeLeaf {
col: right_col,
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
},
};
(left, right)
};
let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
let node = InnerNode { left: left.hash, right: right.hash };
let hash = node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
// If this hash is empty, then it doesn't become a new inner node, nor does it count
// as a leaf for the next depth.
if hash != equivalent_empty_hash {
inner_nodes.insert(index, node);
next_leaves.push(SubtreeLeaf { col: index.value(), hash });
}
}
// Stop borrowing `leaves`, so we can swap it.
// The iterator is empty at this point anyway.
drop(iter);
// After each depth, consider the stuff we just made the new "leaves", and empty the
// other collection.
mem::swap(&mut leaves, &mut next_leaves);
}
debug_assert_eq!(leaves.len(), 1);
let root = leaves.pop().unwrap();
(inner_nodes, root)
}
/// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part:
/// - If `first_leaf` is a right child, the left child is copied from the `parent_node`.
/// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also mutated
/// or copied from the `parent_node`.
///
/// Returns the `InnerNode` containing the hashes of the sibling pair.
fn fetch_sibling_pair(
iter: &mut core::iter::Peekable<alloc::vec::Drain<SubtreeLeaf>>,
first_leaf: SubtreeLeaf,
parent_node: InnerNode,
) -> InnerNode {
let is_right_node = first_leaf.col.is_odd();
if is_right_node {
let left_leaf = SubtreeLeaf {
col: first_leaf.col - 1,
hash: parent_node.left,
};
InnerNode {
left: left_leaf.hash,
right: first_leaf.hash,
}
} else {
let right_col = first_leaf.col + 1;
let right_leaf = match iter.peek().copied() {
Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(),
_ => SubtreeLeaf { col: right_col, hash: parent_node.right },
};
InnerNode {
left: first_leaf.hash,
right: right_leaf.hash,
}
}
}
#[cfg(feature = "internal")]
pub fn build_subtree_for_bench(
leaves: Vec<SubtreeLeaf>,
tree_depth: u8,
bottom_depth: u8,
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
build_subtree(leaves, tree_depth, bottom_depth)
}

View file

@ -1,459 +0,0 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use super::{
build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest,
Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE,
SMT_DEPTH, SUBTREE_DEPTH,
};
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}
#[test]
fn test_sorted_pairs_to_leaves() {
let entries: Vec<(RpoDigest, Word)> = vec![
// Subtree 0.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
// Leaf index collision.
(RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
// Subtree 1. Normal single leaf again.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
// Subtree 2. Another normal leaf.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
];
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let control_leaves: Vec<SmtLeaf> = {
let mut entries_iter = entries.iter().cloned();
let mut next_entry = || entries_iter.next().unwrap();
let control_leaves = vec![
// Subtree 0.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
// Subtree 1.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
// Subtree 2.
SmtLeaf::Single(next_entry()),
];
assert_eq!(entries_iter.next(), None);
control_leaves
};
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
let mut control_leaves_iter = control_leaves.iter();
let mut next_leaf = || control_leaves_iter.next().unwrap();
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
// Subtree 0.
vec![next_leaf(), next_leaf(), next_leaf()],
// Subtree 1.
vec![next_leaf(), next_leaf()],
// Subtree 2.
vec![next_leaf()],
]
.map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
.to_vec();
assert_eq!(control_leaves_iter.next(), None);
control_subtree_leaves
};
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
// This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves);
// Flattening and re-separating out the leaves into subtrees should have the same result.
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
assert_eq!(subtrees.leaves, re_grouped);
// Then finally we might as well check the computed leaf nodes too.
let control_leaves: BTreeMap<u64, SmtLeaf> = control
.leaves()
.map(|(index, value)| (index.index.value(), value.clone()))
.collect();
for (column, test_leaf) in subtrees.nodes {
if test_leaf.is_empty() {
continue;
}
let control_leaf = control_leaves
.get(&column)
.unwrap_or_else(|| panic!("no leaf node found for column {column}"));
assert_eq!(control_leaf, &test_leaf);
}
}
// Helper for the below tests.
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
(0..pair_count)
.map(|i| {
let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64;
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect()
}
fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> {
const REMOVAL_PROBABILITY: f64 = 0.2;
let mut rng = thread_rng();
// Assertion to ensure input keys are unique
assert!(
entries.iter().map(|(key, _)| key).collect::<BTreeSet<_>>().len() == entries.len(),
"Input entries contain duplicate keys!"
);
let mut sorted_entries: Vec<(RpoDigest, Word)> = entries
.into_iter()
.choose_multiple(&mut rng, updates)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};
(key, value)
})
.collect();
sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value());
sorted_entries
}
#[test]
fn test_single_subtree() {
// A single subtree's worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
// `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
let leaves = leaves.into_iter().next().unwrap();
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
assert!(!first_subtree.is_empty());
// The inner nodes computed from that subtree should match the nodes in our control tree.
for (index, node) in first_subtree.into_iter() {
let control = control.get_inner_node(index);
assert_eq!(
control, node,
"subtree-computed node at index {index:?} does not match control",
);
}
// The root returned should also match the equivalent node in the control tree.
let control_root_index =
NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
let control_root_node = control.get_inner_node(control_root_index);
let control_hash = control_root_node.hash();
assert_eq!(
control_hash, subtree_root.hash,
"Subtree-computed root at index {control_root_index:?} does not match control"
);
}
// Test that not just can we compute a subtree correctly, but we can feed the results of one
// subtree into computing another. In other words, test that `build_subtree()` is correctly
// composable.
#[test]
fn test_two_subtrees() {
// Two subtrees' worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
// With two subtrees' worth of leaves, we should have exactly two subtrees.
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
assert_eq!(first.len(), second.len());
let mut current_depth = SMT_DEPTH;
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
next_leaves.push(first_root);
let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
next_leaves.push(second_root);
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
assert_eq!(total_computed as u64, PAIR_COUNT);
// Verify the computed nodes of both subtrees.
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
for (index, test_node) in computed_nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
current_depth -= SUBTREE_DEPTH;
let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
assert_eq!(root_leaf.col, 0);
for (index, test_node) in nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
let control_root = control.get_inner_node(index).hash();
assert_eq!(control_root, root_leaf.hash, "Root mismatch");
}
#[test]
fn test_singlethreaded_subtrees() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
// Do actual things.
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, subtree_root)
})
.unzip();
// Update state between each depth iteration.
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, first checking length and then checking each individual
// leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}
/// The parallel version of `test_singlethreaded_subtree()`.
#[test]
fn test_multithreaded_subtrees() {
use rayon::prelude::*;
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
// Post-assertions.
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, subtree_root)
})
.unzip();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf);
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}
#[test]
fn test_with_entries_concurrent() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let smt = Smt::with_entries(entries.clone()).unwrap();
assert_eq!(smt.root(), control.root());
assert_eq!(smt, control);
}
/// Concurrent mutations
#[test]
fn test_singlethreaded_subtree_mutations() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let updates = generate_updates(entries.clone(), 1000);
let tree = Smt::with_entries_sequential(entries.clone()).unwrap();
let control = tree.compute_mutations_sequential(updates.clone());
let mut node_mutations = NodeMutations::default();
let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates);
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
.into_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
// Calculate the mutations for this subtree.
let (mutations_per_subtree, subtree_root) =
tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth);
// Check that the mutations match the control tree.
for (&index, mutation) in mutations_per_subtree.iter() {
let control_mutation = control.node_mutations().get(&index).unwrap();
assert_eq!(
control_mutation, mutation,
"depth {} subtree {}: mutation does not match control at index {:?}",
current_depth, i, index,
);
}
(mutations_per_subtree, subtree_root)
})
.unzip();
subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
node_mutations.extend(mutations_per_subtree.into_iter().flatten());
assert!(!subtree_leaves.is_empty(), "on depth {current_depth}");
}
let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap();
let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap();
// Check that the new root matches the control.
assert_eq!(control.new_root, root_leaf.hash);
// Check that the node mutations match the control.
assert_eq!(control.node_mutations().len(), node_mutations.len());
for (&index, mutation) in control.node_mutations().iter() {
let test_mutation = node_mutations.get(&index).unwrap();
assert_eq!(test_mutation, mutation);
}
// Check that the new pairs match the control
assert_eq!(control.new_pairs.len(), new_pairs.len());
for (&key, &value) in control.new_pairs.iter() {
let test_value = new_pairs.get(&key).unwrap();
assert_eq!(test_value, &value);
}
}
#[test]
fn test_compute_mutations_parallel() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
let entries = generate_entries(PAIR_COUNT);
let tree = Smt::with_entries(entries.clone()).unwrap();
let updates = generate_updates(entries, 1000);
let control = tree.compute_mutations_sequential(updates.clone());
let mutations = tree.compute_mutations(updates);
assert_eq!(mutations.root(), control.root());
assert_eq!(mutations.old_root(), control.old_root());
assert_eq!(mutations.node_mutations(), control.node_mutations());
assert_eq!(mutations.new_pairs(), control.new_pairs());
}

View file

@ -1,39 +0,0 @@
use thiserror::Error;
use crate::{
hash::rpo::RpoDigest,
merkle::{LeafIndex, SMT_DEPTH},
};
// SMT LEAF ERROR
// =================================================================================================
#[derive(Debug, Error)]
pub enum SmtLeafError {
#[error(
"multiple leaf requires all keys to map to the same leaf index but key1 {key_1} and key2 {key_2} map to different indices"
)]
InconsistentMultipleLeafKeys { key_1: RpoDigest, key_2: RpoDigest },
#[error("single leaf key {key} maps to {actual_leaf_index:?} but was expected to map to {expected_leaf_index:?}")]
InconsistentSingleLeafIndices {
key: RpoDigest,
expected_leaf_index: LeafIndex<SMT_DEPTH>,
actual_leaf_index: LeafIndex<SMT_DEPTH>,
},
#[error("supplied leaf index {leaf_index_supplied:?} does not match {leaf_index_from_keys:?} for multiple leaf")]
InconsistentMultipleLeafIndices {
leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
leaf_index_supplied: LeafIndex<SMT_DEPTH>,
},
#[error("multiple leaf requires at least two entries but only {0} were given")]
MultipleLeafRequiresTwoEntries(usize),
}
// SMT PROOF ERROR
// =================================================================================================
#[derive(Debug, Error)]
pub enum SmtProofError {
#[error("merkle path length {0} does not match SMT depth {SMT_DEPTH}")]
InvalidMerklePathLength(usize),
}

View file

@ -1,373 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use core::cmp::Ordering;
use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
Empty(LeafIndex<SMT_DEPTH>),
Single((RpoDigest, Word)),
Multiple(Vec<(RpoDigest, Word)>),
}
impl SmtLeaf {
// CONSTRUCTORS
// ---------------------------------------------------------------------------------------------
/// Returns a new leaf with the specified entries
///
/// # Errors
/// - Returns an error if 2 keys in `entries` map to a different leaf index
/// - Returns an error if 1 or more keys in `entries` map to a leaf index different from
/// `leaf_index`
pub fn new(
entries: Vec<(RpoDigest, Word)>,
leaf_index: LeafIndex<SMT_DEPTH>,
) -> Result<Self, SmtLeafError> {
match entries.len() {
0 => Ok(Self::new_empty(leaf_index)),
1 => {
let (key, value) = entries[0];
let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
if computed_index != leaf_index {
return Err(SmtLeafError::InconsistentSingleLeafIndices {
key,
expected_leaf_index: leaf_index,
actual_leaf_index: computed_index,
});
}
Ok(Self::new_single(key, value))
},
_ => {
let leaf = Self::new_multiple(entries)?;
// `new_multiple()` checked that all keys map to the same leaf index. We still need
// to ensure that leaf index is `leaf_index`.
if leaf.index() != leaf_index {
Err(SmtLeafError::InconsistentMultipleLeafIndices {
leaf_index_from_keys: leaf.index(),
leaf_index_supplied: leaf_index,
})
} else {
Ok(leaf)
}
},
}
}
/// Returns a new empty leaf with the specified leaf index
pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
Self::Empty(leaf_index)
}
/// Returns a new single leaf with the specified entry. The leaf index is derived from the
/// entry's key.
pub fn new_single(key: RpoDigest, value: Word) -> Self {
Self::Single((key, value))
}
/// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
/// entries' keys.
///
/// # Errors
/// - Returns an error if 2 keys in `entries` map to a different leaf index
pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
if entries.len() < 2 {
return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
}
// Check that all keys map to the same leaf index
{
let mut keys = entries.iter().map(|(key, _)| key);
let first_key = *keys.next().expect("ensured at least 2 entries");
let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
for &next_key in keys {
let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
if next_leaf_index != first_leaf_index {
return Err(SmtLeafError::InconsistentMultipleLeafKeys {
key_1: first_key,
key_2: next_key,
});
}
}
}
Ok(Self::Multiple(entries))
}
// PUBLIC ACCESSORS
// ---------------------------------------------------------------------------------------------
/// Returns true if the leaf is empty
pub fn is_empty(&self) -> bool {
matches!(self, Self::Empty(_))
}
/// Returns the leaf's index in the [`super::Smt`]
pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
match self {
SmtLeaf::Empty(leaf_index) => *leaf_index,
SmtLeaf::Single((key, _)) => key.into(),
SmtLeaf::Multiple(entries) => {
// Note: All keys are guaranteed to have the same leaf index
let (first_key, _) = entries[0];
first_key.into()
},
}
}
/// Returns the number of entries stored in the leaf
pub fn num_entries(&self) -> u64 {
match self {
SmtLeaf::Empty(_) => 0,
SmtLeaf::Single(_) => 1,
SmtLeaf::Multiple(entries) => {
entries.len().try_into().expect("shouldn't have more than 2^64 entries")
},
}
}
/// Computes the hash of the leaf
pub fn hash(&self) -> RpoDigest {
match self {
SmtLeaf::Empty(_) => EMPTY_WORD.into(),
SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
SmtLeaf::Multiple(kvs) => {
let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
Rpo256::hash_elements(&elements)
},
}
}
// ITERATORS
// ---------------------------------------------------------------------------------------------
/// Returns the key-value pairs in the leaf
pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
match self {
SmtLeaf::Empty(_) => Vec::new(),
SmtLeaf::Single(kv_pair) => vec![kv_pair],
SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
}
}
// CONVERSIONS
// ---------------------------------------------------------------------------------------------
/// Converts a leaf to a list of field elements
pub fn to_elements(&self) -> Vec<Felt> {
self.clone().into_elements()
}
/// Converts a leaf to a list of field elements
pub fn into_elements(self) -> Vec<Felt> {
self.into_entries().into_iter().flat_map(kv_to_elements).collect()
}
/// Converts a leaf the key-value pairs in the leaf
pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
match self {
SmtLeaf::Empty(_) => Vec::new(),
SmtLeaf::Single(kv_pair) => vec![kv_pair],
SmtLeaf::Multiple(kv_pairs) => kv_pairs,
}
}
// HELPERS
// ---------------------------------------------------------------------------------------------
/// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
/// leaf.
pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
// Ensure that `key` maps to this leaf
if self.index() != key.into() {
return None;
}
match self {
SmtLeaf::Empty(_) => Some(EMPTY_WORD),
SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
if key == key_in_leaf {
Some(*value_in_leaf)
} else {
Some(EMPTY_WORD)
}
},
SmtLeaf::Multiple(kv_pairs) => {
for (key_in_leaf, value_in_leaf) in kv_pairs {
if key == key_in_leaf {
return Some(*value_in_leaf);
}
}
Some(EMPTY_WORD)
},
}
}
/// Inserts key-value pair into the leaf; returns the previous value associated with `key`, if
/// any.
///
/// The caller needs to ensure that `key` has the same leaf index as all other keys in the leaf
pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
match self {
SmtLeaf::Empty(_) => {
*self = SmtLeaf::new_single(key, value);
None
},
SmtLeaf::Single(kv_pair) => {
if kv_pair.0 == key {
// the key is already in this leaf. Update the value and return the previous
// value
let old_value = kv_pair.1;
kv_pair.1 = value;
Some(old_value)
} else {
// Another entry is present in this leaf. Transform the entry into a list
// entry, and make sure the key-value pairs are sorted by key
let mut pairs = vec![*kv_pair, (key, value)];
pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
*self = SmtLeaf::Multiple(pairs);
None
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
let old_value = kv_pairs[pos].1;
kv_pairs[pos].1 = value;
Some(old_value)
},
Err(pos) => {
kv_pairs.insert(pos, (key, value));
None
},
}
},
}
}
/// Removes key-value pair from the leaf stored at key; returns the previous value associated
/// with `key`, if any. Also returns an `is_empty` flag, indicating whether the leaf became
/// empty, and must be removed from the data structure it is contained in.
pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
match self {
SmtLeaf::Empty(_) => (None, false),
SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
if *key_at_leaf == key {
// our key was indeed stored in the leaf, so we return the value that was stored
// in it, and indicate that the leaf should be removed
let old_value = *value_at_leaf;
// Note: this is not strictly needed, since the caller is expected to drop this
// `SmtLeaf` object.
*self = SmtLeaf::new_empty(key.into());
(Some(old_value), true)
} else {
// another key is stored at leaf; nothing to update
(None, false)
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
let old_value = kv_pairs[pos].1;
kv_pairs.remove(pos);
debug_assert!(!kv_pairs.is_empty());
if kv_pairs.len() == 1 {
// convert the leaf into `Single`
*self = SmtLeaf::Single(kv_pairs[0]);
}
(Some(old_value), false)
},
Err(_) => {
// other keys are stored at leaf; nothing to update
(None, false)
},
}
},
}
}
}
impl Serializable for SmtLeaf {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// Write: num entries
self.num_entries().write_into(target);
// Write: leaf index
let leaf_index: u64 = self.index().value();
leaf_index.write_into(target);
// Write: entries
for (key, value) in self.entries() {
key.write_into(target);
value.write_into(target);
}
}
}
impl Deserializable for SmtLeaf {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
// Read: num entries
let num_entries = source.read_u64()?;
// Read: leaf index
let leaf_index: LeafIndex<SMT_DEPTH> = {
let value = source.read_u64()?;
LeafIndex::new_max_depth(value)
};
// Read: entries
let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
for _ in 0..num_entries {
let key: RpoDigest = source.read()?;
let value: Word = source.read()?;
entries.push((key, value));
}
Self::new(entries, leaf_index)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Converts a key-value tuple to an iterator of `Felt`s
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
let key_elements = key.into_iter();
let value_elements = value.into_iter();
key_elements.chain(value_elements)
}
/// Compares two keys, compared element-by-element using their integer representations starting with
/// the most significant element.
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
let v1 = v1.as_int();
let v2 = v2.as_int();
if v1 != v2 {
return v1.cmp(&v2);
}
}
Ordering::Equal
}

View file

@ -1,548 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
mod error;
pub use error::{SmtLeafError, SmtProofError};
mod leaf;
pub use leaf::SmtLeaf;
mod proof;
pub use proof::SmtProof;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
// Concurrent implementation
#[cfg(feature = "concurrent")]
mod concurrent;
#[cfg(feature = "internal")]
pub use concurrent::{build_subtree_for_bench, SubtreeLeaf};
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================
pub const SMT_DEPTH: u8 = 64;
// SMT
// ================================================================================================
type Leaves = super::Leaves<SmtLeaf>;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
/// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf
/// to which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
/// second.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
// pub(super) for use in PartialSmt.
pub(super) leaves: Leaves,
inner_nodes: InnerNodes,
}
impl Smt {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// The default value used to compute the hash of empty leaves
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [Smt].
///
/// All leaves in the returned tree are set to [Self::EMPTY_VALUE].
pub fn new() -> Self {
let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
Self {
root,
inner_nodes: Default::default(),
leaves: Default::default(),
}
}
/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
///
/// If the `concurrent` feature is enabled, this function uses a parallel implementation to
/// process the entries efficiently, otherwise it defaults to the sequential implementation.
///
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub fn with_entries(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
#[cfg(feature = "concurrent")]
{
Self::with_entries_concurrent(entries)
}
#[cfg(not(feature = "concurrent"))]
{
Self::with_entries_sequential(entries)
}
}
/// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
///
/// This sequential implementation processes entries one at a time to build the tree.
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
#[cfg(any(not(feature = "concurrent"), test))]
fn with_entries_sequential(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
use alloc::collections::BTreeSet;
// create an empty tree
let mut tree = Self::new();
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
// entries with the empty value need additional tracking.
let mut key_set_to_zero = BTreeSet::new();
for (key, value) in entries {
let old_value = tree.insert(key, value);
if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
return Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
));
}
if value == EMPTY_WORD {
key_set_to_zero.insert(key);
};
}
Ok(tree)
}
/// Returns a new [`Smt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the depth of the tree
pub const fn depth(&self) -> u8 {
SMT_DEPTH
}
/// Returns the root of the tree
pub fn root(&self) -> RpoDigest {
<Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
}
/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the leaf to which `key` maps
pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
<Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
}
/// Returns the value associated with `key`
pub fn get_value(&self, key: &RpoDigest) -> Word {
<Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
pub fn open(&self, key: &RpoDigest) -> SmtProof {
<Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
}
/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the leaves of this [Smt].
pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
self.leaves
.iter()
.map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
}
/// Returns an iterator over the key-value pairs of this [Smt].
pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
self.leaves().flat_map(|(_, leaf)| leaf.entries())
}
/// Returns an iterator over the inner nodes of this [Smt].
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.inner_nodes.values().map(|e| InnerNodeInfo {
value: e.hash(),
left: e.left,
right: e.right,
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Inserts a value at the specified key, returning the previous value associated with that key.
/// Recall that by definition, any key that hasn't been updated is associated with
/// [`Self::EMPTY_VALUE`].
///
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
/// updating the root itself.
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
}
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
/// tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt = Smt::new();
/// let pair = (RpoDigest::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
#[cfg(feature = "concurrent")]
{
self.compute_mutations_concurrent(kv_pairs)
}
#[cfg(not(feature = "concurrent"))]
{
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}
}
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree
/// and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
// HELPERS
// --------------------------------------------------------------------------------------------
/// Inserts `value` at leaf index pointed to by `key`. `value` is guaranteed to not be the empty
/// value, such that this is indeed an insertion.
fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
debug_assert_ne!(value, Self::EMPTY_VALUE);
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
match self.leaves.get_mut(&leaf_index.value()) {
Some(leaf) => leaf.insert(key, value),
None => {
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
None
},
}
}
/// Removes key-value pair at leaf index pointed to by `key` if it exists.
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
let (old_value, is_empty) = leaf.remove(key);
if is_empty {
self.leaves.remove(&leaf_index.value());
}
old_value
} else {
// there's nothing stored at the leaf; nothing to update
None
}
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {
type Key = RpoDigest;
type Value = Word;
type Leaf = SmtLeaf;
type Opening = SmtProof;
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts(
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn root(&self) -> RpoDigest {
self.root
}
fn set_root(&mut self, root: RpoDigest) {
self.root = root;
}
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
// inserting an `EMPTY_VALUE` is equivalent to removing any value associated with `key`
if value != Self::EMPTY_VALUE {
self.perform_insert(key, value)
} else {
self.perform_remove(key)
}
}
fn get_value(&self, key: &Self::Key) -> Self::Value {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
}
fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.clone(),
None => SmtLeaf::new_empty(key.into()),
}
}
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
leaf.hash()
}
fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
let most_significant_felt = key[3];
LeafIndex::new_max_depth(most_significant_felt.as_int())
}
fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
SmtProof::new_unchecked(path, leaf)
}
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
assert!(!pairs.is_empty());
if pairs.len() > 1 {
SmtLeaf::new_multiple(pairs).unwrap()
} else {
let (key, value) = pairs.pop().unwrap();
// TODO: should we ever be constructing empty leaves from pairs?
if value == Self::EMPTY_VALUE {
let index = Self::key_to_leaf_index(&key);
SmtLeaf::new_empty(index)
} else {
SmtLeaf::new_single(key, value)
}
}
}
}
impl Default for Smt {
fn default() -> Self {
Self::new()
}
}
// CONVERSIONS
// ================================================================================================
impl From<Word> for LeafIndex<SMT_DEPTH> {
fn from(value: Word) -> Self {
// We use the most significant `Felt` of a `Word` as the leaf index.
Self::new_max_depth(value[3].as_int())
}
}
impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
fn from(value: RpoDigest) -> Self {
Word::from(value).into()
}
}
impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
fn from(value: &RpoDigest) -> Self {
Word::from(value).into()
}
}
// SERIALIZATION
// ================================================================================================
impl Serializable for Smt {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// Write the number of filled leaves for this Smt
target.write_usize(self.entries().count());
// Write each (key, value) pair
for (key, value) in self.entries() {
target.write(key);
target.write(value);
}
}
fn get_size_hint(&self) -> usize {
let entries_count = self.entries().count();
// Each entry is the size of a digest plus a word.
entries_count.get_size_hint()
+ entries_count * (RpoDigest::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint())
}
}
impl Deserializable for Smt {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
// Read the number of filled leaves for this Smt
let num_filled_leaves = source.read_usize()?;
let mut entries = Vec::with_capacity(num_filled_leaves);
for _ in 0..num_filled_leaves {
let key = source.read()?;
let value = source.read()?;
entries.push((key, value));
}
Self::with_entries(entries)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}
// TESTS
// ================================================================================================
#[test]
fn test_smt_serialization_deserialization() {
// Smt for default types (empty map)
let smt_default = Smt::default();
let bytes = smt_default.to_bytes();
assert_eq!(smt_default, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt_default.get_size_hint());
// Smt with values
let smt_leaves_2: [(RpoDigest, Word); 2] = [
(
RpoDigest::new([Felt::new(101), Felt::new(102), Felt::new(103), Felt::new(104)]),
[Felt::new(1_u64), Felt::new(2_u64), Felt::new(3_u64), Felt::new(4_u64)],
),
(
RpoDigest::new([Felt::new(105), Felt::new(106), Felt::new(107), Felt::new(108)]),
[Felt::new(5_u64), Felt::new(6_u64), Felt::new(7_u64), Felt::new(8_u64)],
),
];
let smt = Smt::with_entries(smt_leaves_2).unwrap();
let bytes = smt.to_bytes();
assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt.get_size_hint());
}

View file

@ -1,115 +0,0 @@
use alloc::string::ToString;
use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
///
/// The proof consists of a Merkle path and leaf which describes the node located at the base of the
/// path.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmtProof {
path: MerklePath,
leaf: SmtLeaf,
}
impl SmtProof {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
///
/// # Errors
/// Returns an error if the path length is not [`SMT_DEPTH`].
pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
let depth: usize = SMT_DEPTH.into();
if path.len() != depth {
return Err(SmtProofError::InvalidMerklePathLength(path.len()));
}
Ok(Self { path, leaf })
}
/// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
///
/// The length of the path is not checked. Reserved for internal use.
pub(super) fn new_unchecked(path: MerklePath, leaf: SmtLeaf) -> Self {
Self { path, leaf }
}
// PROOF VERIFIER
// --------------------------------------------------------------------------------------------
/// Returns true if a [`super::Smt`] with the specified root contains the provided
/// key-value pair.
///
/// Note: this method cannot be used to assert non-membership. That is, if false is returned,
/// it does not mean that the provided key-value pair is not in the tree.
pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
let maybe_value_in_leaf = self.leaf.get_value(key);
match maybe_value_in_leaf {
Some(value_in_leaf) => {
// The value must match for the proof to be valid
if value_in_leaf != *value {
return false;
}
// make sure the Merkle path resolves to the correct root
self.compute_root() == *root
},
// If the key maps to a different leaf, the proof cannot verify membership of `value`
None => false,
}
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the value associated with the specific key according to this proof, or None if
/// this proof does not contain a value for the specified key.
///
/// A key-value pair generated by using this method should pass the `verify_membership()` check.
pub fn get(&self, key: &RpoDigest) -> Option<Word> {
self.leaf.get_value(key)
}
/// Computes the root of a [`super::Smt`] to which this proof resolves.
pub fn compute_root(&self) -> RpoDigest {
self.path
.compute_root(self.leaf.index().value(), self.leaf.hash())
.expect("failed to compute Merkle path root")
}
/// Returns the proof's Merkle path.
pub fn path(&self) -> &MerklePath {
&self.path
}
/// Returns the leaf associated with the proof.
pub fn leaf(&self) -> &SmtLeaf {
&self.leaf
}
/// Consume the proof and returns its parts.
pub fn into_parts(self) -> (MerklePath, SmtLeaf) {
(self.path, self.leaf)
}
}
impl Serializable for SmtProof {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.path.write_into(target);
self.leaf.write_into(target);
}
}
impl Deserializable for SmtProof {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let path = MerklePath::read_from(source)?;
let leaf = SmtLeaf::read_from(source)?;
Self::new(path, leaf).map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}

View file

@ -1,724 +0,0 @@
use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{
smt::{NodeMutation, SparseMerkleTree, UnorderedMap},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE,
};
// SMT
// --------------------------------------------------------------------------------------------
/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the key is alone in its leaf
#[test]
fn test_smt_insert_at_same_key() {
let mut smt = Smt::default();
let mut store: MerkleStore = MerkleStore::default();
assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
let key_1: RpoDigest = {
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
};
let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();
let value_1 = [ONE; WORD_SIZE];
let value_2 = [ONE + ONE; WORD_SIZE];
// Insert value 1 and ensure root is as expected
{
let leaf_node = build_empty_or_single_leaf_node(key_1, value_1);
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
let old_value_1 = smt.insert(key_1, value_1);
assert_eq!(old_value_1, EMPTY_WORD);
assert_eq!(smt.root(), tree_root);
}
// Insert value 2 and ensure root is as expected
{
let leaf_node = build_empty_or_single_leaf_node(key_1, value_2);
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
let old_value_2 = smt.insert(key_1, value_2);
assert_eq!(old_value_2, value_1);
assert_eq!(smt.root(), tree_root);
}
}
/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the leaf type is `SmtLeaf::Multiple`
#[test]
fn test_smt_insert_at_same_key_2() {
// The most significant u64 used for both keys (to ensure they map to the same leaf)
let key_msb: u64 = 42;
let key_already_present: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(key_msb)]);
let key_already_present_index: NodeIndex =
LeafIndex::<SMT_DEPTH>::from(key_already_present).into();
let value_already_present = [ONE + ONE + ONE; WORD_SIZE];
let mut smt =
Smt::with_entries(core::iter::once((key_already_present, value_already_present))).unwrap();
let mut store: MerkleStore = {
let mut store = MerkleStore::default();
let leaf_node = build_empty_or_single_leaf_node(key_already_present, value_already_present);
store
.set_node(*EmptySubtreeRoots::entry(SMT_DEPTH, 0), key_already_present_index, leaf_node)
.unwrap();
store
};
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(key_msb)]);
let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();
assert_eq!(key_1_index, key_already_present_index);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [ONE + ONE; WORD_SIZE];
// Insert value 1 and ensure root is as expected
{
// Note: key_1 comes first because it is smaller
let leaf_node = build_multiple_leaf_node(&[
(key_1, value_1),
(key_already_present, value_already_present),
]);
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
let old_value_1 = smt.insert(key_1, value_1);
assert_eq!(old_value_1, EMPTY_WORD);
assert_eq!(smt.root(), tree_root);
}
// Insert value 2 and ensure root is as expected
{
let leaf_node = build_multiple_leaf_node(&[
(key_1, value_2),
(key_already_present, value_already_present),
]);
let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;
let old_value_2 = smt.insert(key_1, value_2);
assert_eq!(old_value_2, value_1);
assert_eq!(smt.root(), tree_root);
}
}
/// This test ensures that the root of the tree is as expected when we add/remove 3 items at 3
/// different keys. This also tests that the merkle paths produced are as expected.
#[test]
fn test_smt_insert_and_remove_multiple_values() {
fn insert_values_and_assert_path(
smt: &mut Smt,
store: &mut MerkleStore,
key_values: &[(RpoDigest, Word)],
) {
for &(key, value) in key_values {
let key_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key).into();
let leaf_node = build_empty_or_single_leaf_node(key, value);
let tree_root = store.set_node(smt.root(), key_index, leaf_node).unwrap().root;
let _ = smt.insert(key, value);
assert_eq!(smt.root(), tree_root);
let expected_path = store.get_path(tree_root, key_index).unwrap();
assert_eq!(smt.open(&key).into_parts().0, expected_path.path);
}
}
let mut smt = Smt::default();
let mut store: MerkleStore = MerkleStore::default();
assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
let key_1: RpoDigest = {
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
};
let key_2: RpoDigest = {
let raw = 0b_11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64;
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
};
let key_3: RpoDigest = {
let raw = 0b_00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000_u64;
RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
};
let value_1 = [ONE; WORD_SIZE];
let value_2 = [ONE + ONE; WORD_SIZE];
let value_3 = [ONE + ONE + ONE; WORD_SIZE];
// Insert values in the tree
let key_values = [(key_1, value_1), (key_2, value_2), (key_3, value_3)];
insert_values_and_assert_path(&mut smt, &mut store, &key_values);
// Remove values from the tree
let key_empty_values = [(key_1, EMPTY_WORD), (key_2, EMPTY_WORD), (key_3, EMPTY_WORD)];
insert_values_and_assert_path(&mut smt, &mut store, &key_empty_values);
let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
assert_eq!(smt.root(), empty_root);
// an empty tree should have no leaves or inner nodes
assert!(smt.leaves.is_empty());
assert!(smt.inner_nodes.is_empty());
}
/// This tests that inserting the empty value does indeed remove the key-value contained at the
/// leaf. We insert & remove 3 values at the same leaf to ensure that all cases are covered (empty,
/// single, multiple).
#[test]
fn test_smt_removal() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
let key_3: RpoDigest =
RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
// insert key-value 1
{
let old_value_1 = smt.insert(key_1, value_1);
assert_eq!(old_value_1, EMPTY_WORD);
assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1)));
}
// insert key-value 2
{
let old_value_2 = smt.insert(key_2, value_2);
assert_eq!(old_value_2, EMPTY_WORD);
assert_eq!(
smt.get_leaf(&key_2),
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
);
}
// insert key-value 3
{
let old_value_3 = smt.insert(key_3, value_3);
assert_eq!(old_value_3, EMPTY_WORD);
assert_eq!(
smt.get_leaf(&key_3),
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)])
);
}
// remove key 3
{
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
assert_eq!(
smt.get_leaf(&key_3),
SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
);
}
// remove key 2
{
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1)));
}
// remove key 1
{
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into()));
}
}
/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
// insert key-value 1
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
smt.insert(key_1, value_1);
let leaf = smt.get_leaf(&key_1);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 2
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
smt.insert(key_2, value_2);
let leaf = smt.get_leaf(&key_2);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 3
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
smt.insert(key_3, value_3);
let leaf = smt.get_leaf(&key_3);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// remove key 3
{
let old_leaf = smt.get_leaf(&key_3);
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 2
{
let old_leaf = smt.get_leaf(&key_2);
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 1
{
let old_leaf = smt.get_leaf(&key_1);
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
}
/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
let root_empty = smt.root();
let root_1 = {
smt.insert(key_1, value_1);
smt.root()
};
let root_2 = {
smt.insert(key_2, value_2);
smt.root()
};
let root_3 = {
smt.insert(key_3, value_3);
smt.root()
};
// Test incremental updates.
let mut smt = Smt::default();
let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
revert.node_mutations,
smt.inner_nodes.keys().map(|key| (*key, NodeMutation::Removal)).collect(),
"reverse mutations inner nodes did not match"
);
let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_3);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
UnorderedMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);
// Test batch updates, and that the order doesn't matter.
let pairs =
vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(
mutations.root(),
root_empty,
"prospective root for batch removal did not match actual root",
);
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);
let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
}
#[test]
fn test_mutations_revert() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let original = smt.clone();
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");
smt.apply_mutations(revert).unwrap();
assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
}
#[test]
fn test_mutation_set_serialization() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let serialized = mutations.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, mutations, "deserialized mutations did not match original");
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
let serialized = revert.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, revert, "deserialized mutations did not match original");
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
assert_eq!(smt.open(&key_1), smt.open(&key_2));
}
/// Tests that an empty leaf hashes to the empty word
#[test]
fn test_empty_leaf_hash() {
let smt = Smt::default();
let leaf = smt.get_leaf(&RpoDigest::default());
assert_eq!(leaf.hash(), EMPTY_WORD.into());
}
/// Tests that `get_value()` works as expected
#[test]
fn test_smt_get_value() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();
let returned_value_1 = smt.get_value(&key_1);
let returned_value_2 = smt.get_value(&key_2);
assert_eq!(value_1, returned_value_1);
assert_eq!(value_2, returned_value_2);
// Check that a key with no inserted value returns the empty word
let key_no_value = RpoDigest::from([42_u32, 42_u32, 42_u32, 42_u32]);
assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
}
/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let entries = [(key_1, value_1), (key_2, value_2)];
let smt = Smt::with_entries(entries).unwrap();
let mut expected = Vec::from_iter(entries);
expected.sort_by_key(|(k, _)| *k);
let mut actual: Vec<_> = smt.entries().cloned().collect();
actual.sort_by_key(|(k, _)| *k);
assert_eq!(actual, expected);
}
/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of
/// depth 64
#[test]
fn test_smt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
assert_eq!(empty_root_64_depth, Smt::EMPTY_ROOT);
}
// SMT LEAF
// --------------------------------------------------------------------------------------------
#[test]
fn test_empty_smt_leaf_serialization() {
let empty_leaf = SmtLeaf::new_empty(LeafIndex::new_max_depth(42));
let mut serialized = empty_leaf.to_bytes();
// extend buffer with random bytes
serialized.extend([1, 2, 3, 4, 5]);
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
assert_eq!(empty_leaf, deserialized);
}
#[test]
fn test_single_smt_leaf_serialization() {
let single_leaf = SmtLeaf::new_single(
RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
);
let mut serialized = single_leaf.to_bytes();
// extend buffer with random bytes
serialized.extend([1, 2, 3, 4, 5]);
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
assert_eq!(single_leaf, deserialized);
}
#[test]
fn test_multiple_smt_leaf_serialization_success() {
let multiple_leaf = SmtLeaf::new_multiple(vec![
(
RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
[1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
),
(
RpoDigest::from([100_u32, 101_u32, 102_u32, 13_u32]),
[11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
),
])
.unwrap();
let mut serialized = multiple_leaf.to_bytes();
// extend buffer with random bytes
serialized.extend([1, 2, 3, 4, 5]);
let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();
assert_eq!(multiple_leaf, deserialized);
}
// HELPERS
// --------------------------------------------------------------------------------------------
fn build_empty_or_single_leaf_node(key: RpoDigest, value: Word) -> RpoDigest {
if value == EMPTY_WORD {
SmtLeaf::new_empty(key.into()).hash()
} else {
SmtLeaf::Single((key, value)).hash()
}
}
fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
let elements: Vec<Felt> = kv_pairs
.iter()
.flat_map(|(key, value)| {
let key_elements = key.into_iter();
let value_elements = (*value).into_iter();
key_elements.chain(value_elements)
})
.collect();
Rpo256::hash_elements(&elements)
}
/// Applies mutations with and without reversion to the given SMT, comparing resulting SMTs,
/// returning mutation set for reversion.
fn apply_mutations(
smt: &mut Smt,
mutation_set: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
let mut smt2 = smt.clone();
let reversion = smt.apply_mutations_with_reversion(mutation_set.clone()).unwrap();
smt2.apply_mutations(mutation_set).unwrap();
assert_eq!(&smt2, smt);
reversion
}

View file

@ -1,712 +0,0 @@
use alloc::vec::Vec;
use core::hash::Hash;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
Felt, Word, EMPTY_WORD,
};
mod full;
#[cfg(feature = "internal")]
pub use full::{build_subtree_for_bench, SubtreeLeaf};
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
mod simple;
pub use simple::SimpleSmt;
mod partial;
pub use partial::PartialSmt;
// CONSTANTS
// ================================================================================================
/// Minimum supported depth.
pub const SMT_MIN_DEPTH: u8 = 1;
/// Maximum supported depth.
pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE
// ================================================================================================
/// A map whose keys are not guarantied to be ordered.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
/// value was not explicitly set, then its value is the default value. Typically, the vast majority
/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
/// representation of the tree will only keep track of the leaves that have a different value from
/// the default.
///
/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
/// longer its proofs are - of exactly `log(depth)` size. A tree cannot have depth 0, since such a
/// tree is just a single value, and is probably a programming mistake.
///
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone + Ord + Eq + Hash;
/// The type for a value
type Value: Clone + PartialEq;
/// The type for a leaf
type Leaf: Clone;
/// The type for an opening (i.e. a "proof") of a leaf
type Opening;
/// The default value used to compute the hash of empty leaves
const EMPTY_VALUE: Self::Value;
/// The root of the empty tree with provided DEPTH
const EMPTY_ROOT: RpoDigest;
// PROVIDED METHODS
// ---------------------------------------------------------------------------------------------
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
fn open(&self, key: &Self::Key) -> Self::Opening {
let leaf = self.get_leaf(key);
let mut index: NodeIndex = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
leaf_index.into()
};
let merkle_path = {
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let value = if is_right { left } else { right };
path.push(value);
}
MerklePath::new(path)
};
Self::path_and_leaf_to_opening(merkle_path, leaf)
}
/// Inserts a value at the specified key, returning the previous value associated with that key.
/// Recall that by definition, any key that hasn't been updated is associated with
/// [`Self::EMPTY_VALUE`].
///
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
/// updating the root itself.
fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
// if the old value and new value are the same, there is nothing to update
if value == old_value {
return value;
}
let leaf = self.get_leaf(&key);
let node_index = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
leaf_index.into()
};
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
old_value
}
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
/// `node_hash_at_index` is the hash of the node stored at index.
fn recompute_nodes_from_index_to_root(
&mut self,
mut index: NodeIndex,
node_hash_at_index: RpoDigest,
) {
let mut node_hash = node_hash_at_index;
for node_depth in (0..index.depth()).rev() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let (left, right) = if is_right {
(left, node_hash)
} else {
(node_hash, right)
};
node_hash = Rpo256::merge(&[left, right]);
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, then can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index);
} else {
self.insert_inner_node(index, InnerNode { left, right });
}
}
self.set_root(node_hash);
}
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
self.compute_mutations_sequential(kv_pairs)
}
/// Sequential version of [`SparseMerkleTree::compute_mutations()`].
/// This is the default implementation.
fn compute_mutations_sequential(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: NodeMutations = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
for (index, mutation) in node_mutations {
match mutation {
Removal => {
self.remove_inner_node(index);
},
Addition(node) => {
self.insert_inner_node(index, node);
},
}
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
let mut reverse_mutations = NodeMutations::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
if let Some(node) = self.remove_inner_node(index) {
reverse_mutations.insert(index, Addition(node));
}
},
Addition(node) => {
if let Some(old_node) = self.insert_inner_node(index, node) {
reverse_mutations.insert(index, Addition(old_node));
} else {
reverse_mutations.insert(index, Removal);
}
},
}
}
let mut reverse_pairs = UnorderedMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
} else {
reverse_pairs.insert(key, Self::EMPTY_VALUE);
}
}
self.set_root(new_root);
Ok(MutationSet {
old_root: new_root,
node_mutations: reverse_mutations,
new_pairs: reverse_pairs,
new_root: old_root,
})
}
// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------
/// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other.
fn from_raw_parts(
inner_nodes: InnerNodes,
leaves: Leaves<Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
Self: Sized;
/// The root of the tree
fn root(&self) -> RpoDigest;
/// Sets the root of the tree
fn set_root(&mut self, root: RpoDigest);
/// Retrieves an inner node at the given index
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
/// Inserts an inner node at the given index
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
/// Removes an inner node at the given index
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
/// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
/// updated is associated with [`Self::EMPTY_VALUE`].
fn get_value(&self, key: &Self::Key) -> Self::Value;
/// Returns the leaf at the specified index.
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
/// Returns the hash of a leaf
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
/// mutating the tree itself. The existing leaf can be empty.
///
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
/// as the argument for `existing_leaf`. The return value from this function can be chained back
/// into this function as the first argument to continue making prospective changes.
///
/// # Invariants
/// Because this method is for a prospective key-value insertion into a specific leaf,
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
/// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
/// Constructs a single leaf from an arbitrary amount of key-value pairs.
/// Those pairs must all have the same leaf index.
fn pairs_to_leaf(pairs: Vec<(Self::Key, Self::Value)>) -> Self::Leaf;
/// Maps a (MerklePath, Self::Leaf) to an opening.
///
/// The length `path` is guaranteed to be equal to `DEPTH`
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
}
// INNER NODE
// ================================================================================================
/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not
/// part of the public API.
#[doc(hidden)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
impl InnerNode {
pub fn hash(&self) -> RpoDigest {
Rpo256::merge(&[self.left, self.right])
}
}
// LEAF INDEX
// ================================================================================================
/// The index of a leaf, at a depth known at compile-time.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LeafIndex<const DEPTH: u8> {
index: NodeIndex,
}
impl<const DEPTH: u8> LeafIndex<DEPTH> {
pub fn new(value: u64) -> Result<Self, MerkleError> {
if DEPTH < SMT_MIN_DEPTH {
return Err(MerkleError::DepthTooSmall(DEPTH));
}
Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
}
pub fn value(&self) -> u64 {
self.index.value()
}
}
impl LeafIndex<SMT_MAX_DEPTH> {
pub const fn new_max_depth(value: u64) -> Self {
LeafIndex {
index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
}
}
}
impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
fn from(value: LeafIndex<DEPTH>) -> Self {
value.index
}
}
impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
type Error = MerkleError;
fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
if node_index.depth() != DEPTH {
return Err(MerkleError::InvalidNodeIndexDepth {
expected: DEPTH,
provided: node_index.depth(),
});
}
Self::new(node_index.value())
}
}
impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.index.write_into(target);
}
}
impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
Ok(Self { index: source.read()? })
}
}
// MUTATIONS
// ================================================================================================
/// A change to an inner node of a sparse Merkle tree that hasn't yet been applied.
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
/// Node needs to be removed.
Removal,
/// Node needs to be inserted.
Addition(InnerNode),
}
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
old_root: RpoDigest,
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: NodeMutations,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: UnorderedMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
/// Returns the SMT root before the mutations were applied.
pub fn old_root(&self) -> RpoDigest {
self.old_root
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &NodeMutations {
&self.node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
&self.new_pairs
}
}
// SERIALIZATION
// ================================================================================================
impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.left);
target.write(self.right);
}
}
impl Deserializable for InnerNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = source.read()?;
let right = source.read()?;
Ok(Self { left, right })
}
}
impl Serializable for NodeMutation {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
match self {
NodeMutation::Removal => target.write_bool(false),
NodeMutation::Addition(inner_node) => {
target.write_bool(true);
inner_node.write_into(target);
},
}
}
}
impl Deserializable for NodeMutation {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
if source.read_bool()? {
let inner_node = source.read()?;
return Ok(NodeMutation::Addition(inner_node));
}
Ok(NodeMutation::Removal)
}
}
impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
for MutationSet<DEPTH, K, V>
{
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
let inner_removals: Vec<_> = self
.node_mutations
.iter()
.filter(|(_, value)| matches!(value, NodeMutation::Removal))
.map(|(key, _)| key)
.collect();
let inner_additions: Vec<_> = self
.node_mutations
.iter()
.filter_map(|(key, value)| match value {
NodeMutation::Addition(node) => Some((key, node)),
_ => None,
})
.collect();
target.write(inner_removals);
target.write(inner_additions);
target.write_usize(self.new_pairs.len());
target.write_many(&self.new_pairs);
}
}
impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let inner_removals: Vec<NodeIndex> = source.read()?;
let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
let node_mutations = NodeMutations::from_iter(
inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
inner_additions
.into_iter()
.map(|(index, node)| (index, NodeMutation::Addition(node))),
),
);
let num_new_pairs = source.read_usize()?;
let new_pairs = source.read_many(num_new_pairs)?;
let new_pairs = UnorderedMap::from_iter(new_pairs);
Ok(Self {
old_root,
node_mutations,
new_pairs,
new_root,
})
}
}

View file

@ -1,361 +0,0 @@
use crate::{
hash::rpo::RpoDigest,
merkle::{smt::SparseMerkleTree, InnerNode, MerkleError, MerklePath, Smt, SmtLeaf, SmtProof},
Word, EMPTY_WORD,
};
/// A partial version of an [`Smt`].
///
/// This type can track a subset of the key-value pairs of a full [`Smt`] and allows for updating
/// those pairs to compute the new root of the tree, as if the updates had been done on the full
/// tree. This is useful so that not all leaves have to be present and loaded into memory to compute
/// an update.
///
/// To facilitate this, a partial SMT requires that the merkle paths of every key-value pair are
/// added to the tree. This means this pair is considered "tracked" and can be updated.
///
/// An important caveat is that only pairs whose merkle paths were added can be updated. Attempting
/// to update an untracked value will result in an error. See [`PartialSmt::insert`] for more
/// details.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt(Smt);
impl PartialSmt {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [`PartialSmt`].
///
/// All leaves in the returned tree are set to [`Smt::EMPTY_VALUE`].
pub fn new() -> Self {
Self(Smt::new())
}
/// Instantiates a new [`PartialSmt`] by calling [`PartialSmt::add_path`] for all [`SmtProof`]s
/// in the provided iterator.
///
/// # Errors
///
/// Returns an error if:
/// - the new root after the insertion of a (leaf, path) tuple does not match the existing root
/// (except if the tree was previously empty).
pub fn from_proofs<I>(paths: I) -> Result<Self, MerkleError>
where
I: IntoIterator<Item = SmtProof>,
{
let mut partial_smt = Self::new();
for (leaf, path) in paths.into_iter().map(SmtProof::into_parts) {
partial_smt.add_path(path, leaf)?;
}
Ok(partial_smt)
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the root of the tree.
pub fn root(&self) -> RpoDigest {
self.0.root()
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
///
/// # Errors
///
/// Returns an error if:
/// - the key is not tracked by this partial SMT.
pub fn open(&self, key: &RpoDigest) -> Result<SmtProof, MerkleError> {
if !self.is_leaf_tracked(key) {
return Err(MerkleError::UntrackedKey(*key));
}
Ok(self.0.open(key))
}
/// Returns the leaf to which `key` maps
///
/// # Errors
///
/// Returns an error if:
/// - the key is not tracked by this partial SMT.
pub fn get_leaf(&self, key: &RpoDigest) -> Result<SmtLeaf, MerkleError> {
if !self.is_leaf_tracked(key) {
return Err(MerkleError::UntrackedKey(*key));
}
Ok(self.0.get_leaf(key))
}
/// Returns the value associated with `key`.
///
/// # Errors
///
/// Returns an error if:
/// - the key is not tracked by this partial SMT.
pub fn get_value(&self, key: &RpoDigest) -> Result<Word, MerkleError> {
if !self.is_leaf_tracked(key) {
return Err(MerkleError::UntrackedKey(*key));
}
Ok(self.0.get_value(key))
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Inserts a value at the specified key, returning the previous value associated with that key.
/// Recall that by definition, any key that hasn't been updated is associated with
/// [`Smt::EMPTY_VALUE`].
///
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
/// updating the root itself.
///
/// # Errors
///
/// Returns an error if:
/// - the key and its merkle path were not previously added (using [`PartialSmt::add_path`]) to
/// this [`PartialSmt`], which means it is almost certainly incorrect to update its value. If
/// an error is returned the tree is in the same state as before.
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Result<Word, MerkleError> {
if !self.is_leaf_tracked(&key) {
return Err(MerkleError::UntrackedKey(key));
}
let previous_value = self.0.insert(key, value);
// If the value was removed the SmtLeaf was removed as well by the underlying Smt
// implementation. However, we still want to consider that leaf tracked so it can be
// read and written to, so we reinsert an empty SmtLeaf.
if value == EMPTY_WORD {
let leaf_index = Smt::key_to_leaf_index(&key);
self.0.leaves.insert(leaf_index.value(), SmtLeaf::Empty(leaf_index));
}
Ok(previous_value)
}
/// Adds a leaf and its merkle path to this [`PartialSmt`] and returns the value that
/// was previously present at this key, if any.
///
/// If this function was called, the `key` can subsequently be updated to a new value and
/// produce a correct new tree root.
///
/// # Errors
///
/// Returns an error if:
/// - the new root after the insertion of the leaf and the path does not match the existing root
/// (except if the tree was previously empty). If an error is returned, the tree is left in an
/// inconsistent state.
pub fn add_path(&mut self, leaf: SmtLeaf, path: MerklePath) -> Result<(), MerkleError> {
let mut current_index = leaf.index().index;
let mut node_hash_at_current_index = leaf.hash();
// We insert directly into the leaves for two reasons:
// - We can directly insert the leaf as it is without having to loop over its entries to
// call Smt::perform_insert.
// - If the leaf is SmtLeaf::Empty, we will also insert it, which means this leaf is
// considered tracked by the partial SMT as it is part of the leaves map. When calling
// PartialSmt::insert, this will not error for such empty leaves whose merkle path was
// added, but will error for otherwise non-existent leaves whose paths were not added,
// which is what we want.
self.0.leaves.insert(current_index.value(), leaf);
for sibling_hash in path {
// Find the index of the sibling node and compute whether it is a left or right child.
let is_sibling_right = current_index.sibling().is_value_odd();
// Move the index up so it points to the parent of the current index and the sibling.
current_index.move_up();
// Construct the new parent node from the child that was updated and the sibling from
// the merkle path.
let new_parent_node = if is_sibling_right {
InnerNode {
left: node_hash_at_current_index,
right: sibling_hash,
}
} else {
InnerNode {
left: sibling_hash,
right: node_hash_at_current_index,
}
};
self.0.insert_inner_node(current_index, new_parent_node);
node_hash_at_current_index = self.0.get_inner_node(current_index).hash();
}
// Check the newly added merkle path is consistent with the existing tree. If not, the
// merkle path was invalid or computed from another tree.
// We skip this check if the root is empty since this indicates we're adding the first
// merkle path in which case we have to update the tree root to the root from the path.
if self.root() != Smt::EMPTY_ROOT && self.root() != node_hash_at_current_index {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: node_hash_at_current_index,
});
}
self.0.set_root(node_hash_at_current_index);
Ok(())
}
/// Returns true if the key's merkle path was previously added to this partial SMT and can be
/// sensibly updated to a new value.
/// In particular, this returns true for keys whose value was empty **but** their merkle paths
/// were added, while it returns false if the merkle paths were **not** added.
fn is_leaf_tracked(&self, key: &RpoDigest) -> bool {
self.0.leaves.contains_key(&Smt::key_to_leaf_index(key).value())
}
}
impl Default for PartialSmt {
fn default() -> Self {
Self::new()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use rand_utils::rand_array;
use super::*;
use crate::{EMPTY_WORD, ONE, ZERO};
/// Tests that a basic PartialSmt can be built from a full one and that inserting or removing
/// values whose merkle path were added to the partial SMT results in the same root as the
/// equivalent update in the full tree.
#[test]
fn partial_smt_insert_and_remove() {
let key0 = RpoDigest::from(Word::from(rand_array()));
let key1 = RpoDigest::from(Word::from(rand_array()));
let key2 = RpoDigest::from(Word::from(rand_array()));
// A key for which we won't add a value so it will be empty.
let key_empty = RpoDigest::from(Word::from(rand_array()));
let value0 = Word::from(rand_array());
let value1 = Word::from(rand_array());
let value2 = Word::from(rand_array());
let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];
// Add more random leaves.
kv_pairs.reserve(1000);
for _ in 0..1000 {
let key = RpoDigest::from(Word::from(rand_array()));
let value = Word::from(rand_array());
kv_pairs.push((key, value));
}
let mut full = Smt::with_entries(kv_pairs).unwrap();
// Constructing a partial SMT from proofs succeeds.
// ----------------------------------------------------------------------------------------
let proof0 = full.open(&key0);
let proof2 = full.open(&key2);
let proof_empty = full.open(&key_empty);
assert!(proof_empty.leaf().is_empty());
let mut partial = PartialSmt::from_proofs([proof0, proof2, proof_empty]).unwrap();
assert_eq!(full.root(), partial.root());
assert_eq!(partial.get_value(&key0).unwrap(), value0);
let error = partial.get_value(&key1).unwrap_err();
assert_matches!(error, MerkleError::UntrackedKey(_));
assert_eq!(partial.get_value(&key2).unwrap(), value2);
// Insert new values for added keys with empty and non-empty values.
// ----------------------------------------------------------------------------------------
let new_value0 = Word::from(rand_array());
let new_value2 = Word::from(rand_array());
// A non-empty value for the key that was previously empty.
let new_value_empty_key = Word::from(rand_array());
full.insert(key0, new_value0);
full.insert(key2, new_value2);
full.insert(key_empty, new_value_empty_key);
partial.insert(key0, new_value0).unwrap();
partial.insert(key2, new_value2).unwrap();
// This updates a key whose value was previously empty.
partial.insert(key_empty, new_value_empty_key).unwrap();
assert_eq!(full.root(), partial.root());
assert_eq!(partial.get_value(&key0).unwrap(), new_value0);
assert_eq!(partial.get_value(&key2).unwrap(), new_value2);
assert_eq!(partial.get_value(&key_empty).unwrap(), new_value_empty_key);
// Remove an added key.
// ----------------------------------------------------------------------------------------
full.insert(key0, EMPTY_WORD);
partial.insert(key0, EMPTY_WORD).unwrap();
assert_eq!(full.root(), partial.root());
assert_eq!(partial.get_value(&key0).unwrap(), EMPTY_WORD);
// Check if returned openings are the same in partial and full SMT.
// ----------------------------------------------------------------------------------------
// This is a key whose value is empty since it was removed.
assert_eq!(full.open(&key0), partial.open(&key0).unwrap());
// This is a key whose value is non-empty.
assert_eq!(full.open(&key2), partial.open(&key2).unwrap());
// Attempting to update a key whose merkle path was not added is an error.
// ----------------------------------------------------------------------------------------
let error = partial.clone().insert(key1, Word::from(rand_array())).unwrap_err();
assert_matches!(error, MerkleError::UntrackedKey(_));
let error = partial.insert(key1, EMPTY_WORD).unwrap_err();
assert_matches!(error, MerkleError::UntrackedKey(_));
}
/// Test that we can add an SmtLeaf::Multiple variant to a partial SMT.
#[test]
fn partial_smt_multiple_leaf_success() {
// key0 and key1 have the same felt at index 3 so they will be placed in the same leaf.
let key0 = RpoDigest::from(Word::from([ZERO, ZERO, ZERO, ONE]));
let key1 = RpoDigest::from(Word::from([ONE, ONE, ONE, ONE]));
let key2 = RpoDigest::from(Word::from(rand_array()));
let value0 = Word::from(rand_array());
let value1 = Word::from(rand_array());
let value2 = Word::from(rand_array());
let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();
// Make sure our assumption about the leaf being a multiple is correct.
let SmtLeaf::Multiple(_) = full.get_leaf(&key0) else {
panic!("expected full tree to produce multiple leaf")
};
let proof0 = full.open(&key0);
let proof2 = full.open(&key2);
let partial = PartialSmt::from_proofs([proof0, proof2]).unwrap();
assert_eq!(partial.root(), full.root());
assert_eq!(partial.get_leaf(&key0).unwrap(), full.get_leaf(&key0));
// key1 is present in the partial tree because it is part of the proof of key0.
assert_eq!(partial.get_leaf(&key1).unwrap(), full.get_leaf(&key1));
assert_eq!(partial.get_leaf(&key2).unwrap(), full.get_leaf(&key2));
}
}

View file

@ -1,425 +0,0 @@
use alloc::{collections::BTreeSet, vec::Vec};
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
#[cfg(test)]
mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
type Leaves = super::Leaves<Word>;
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest,
inner_nodes: InnerNodes,
leaves: Leaves,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// The default value used to compute the hash of empty leaves
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [SimpleSmt].
///
/// All leaves in the returned tree are set to [ZERO; 4].
///
/// # Errors
/// Returns an error if DEPTH is 0 or is greater than 64.
pub fn new() -> Result<Self, MerkleError> {
// validate the range of the depth.
if DEPTH < SMT_MIN_DEPTH {
return Err(MerkleError::DepthTooSmall(DEPTH));
} else if SMT_MAX_DEPTH < DEPTH {
return Err(MerkleError::DepthTooBig(DEPTH as u64));
}
let root = *EmptySubtreeRoots::entry(DEPTH, 0);
Ok(Self {
root,
inner_nodes: Default::default(),
leaves: Default::default(),
})
}
/// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
///
/// All leaves omitted from the entries list are set to [ZERO; 4].
///
/// # Errors
/// Returns an error if:
/// - If the depth is 0 or is greater than 64.
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
/// - The provided entries contain multiple values for the same key.
pub fn with_leaves(
entries: impl IntoIterator<Item = (u64, Word)>,
) -> Result<Self, MerkleError> {
// create an empty tree
let mut tree = Self::new()?;
// compute the max number of entries. We use an upper bound of depth 63 because we consider
// passing in a vector of size 2^64 infeasible.
let max_num_entries = 2_usize.pow(DEPTH.min(63).into());
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
// entries with the empty value need additional tracking.
let mut key_set_to_zero = BTreeSet::new();
for (idx, (key, value)) in entries.into_iter().enumerate() {
if idx >= max_num_entries {
return Err(MerkleError::TooManyEntries(max_num_entries));
}
let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
return Err(MerkleError::DuplicateValuesForIndex(key));
}
if value == Self::EMPTY_VALUE {
key_set_to_zero.insert(key);
};
}
Ok(tree)
}
/// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
///
/// This function performs minimal consistency checking. It is the caller's responsibility to
/// ensure the passed arguments are correct and consistent with each other.
///
/// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`.
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
}
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
/// starting at index 0.
pub fn with_contiguous_leaves(
entries: impl IntoIterator<Item = Word>,
) -> Result<Self, MerkleError> {
Self::with_leaves(
entries
.into_iter()
.enumerate()
.map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
)
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the depth of the tree
pub const fn depth(&self) -> u8 {
DEPTH
}
/// Returns the root of the tree
pub fn root(&self) -> RpoDigest {
<Self as SparseMerkleTree<DEPTH>>::root(self)
}
/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the leaf at the specified index.
pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
<Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
}
/// Returns a node at the specified index.
///
/// # Errors
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
/// the depth of this Merkle tree.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
if index.is_root() {
Err(MerkleError::DepthTooSmall(index.depth()))
} else if index.depth() > DEPTH {
Err(MerkleError::DepthTooBig(index.depth() as u64))
} else if index.depth() == DEPTH {
let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
Ok(leaf.into())
} else {
Ok(self.get_inner_node(index).hash())
}
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
<Self as SparseMerkleTree<DEPTH>>::open(self, key)
}
/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the leaves of this [SimpleSmt].
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
self.leaves.iter().map(|(i, w)| (*i, w))
}
/// Returns an iterator over the inner nodes of this [SimpleSmt].
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.inner_nodes.values().map(|e| InnerNodeInfo {
value: e.hash(),
left: e.left,
right: e.right,
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Inserts a value at the specified key, returning the previous value associated with that key.
/// Recall that by definition, any key that hasn't been updated is associated with
/// [`EMPTY_WORD`].
///
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
/// updating the root itself.
pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
}
/// Computes what changes are necessary to insert the specified key-value pairs into this
/// Merkle tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
/// Merkle tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
/// let pair = (LeafIndex::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
/// this tree and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
/// Returns the new root.
pub fn set_subtree<const SUBTREE_DEPTH: u8>(
&mut self,
subtree_insertion_index: u64,
subtree: SimpleSmt<SUBTREE_DEPTH>,
) -> Result<RpoDigest, MerkleError> {
if SUBTREE_DEPTH > DEPTH {
return Err(MerkleError::SubtreeDepthExceedsDepth {
subtree_depth: SUBTREE_DEPTH,
tree_depth: DEPTH,
});
}
// Verify that `subtree_insertion_index` is valid.
let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
let subtree_root_index =
NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
// add leaves
// --------------
// The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
// insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
// valid as they are. However, consider what happens when we insert at
// `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
// you can see it as there's a full subtree sitting on its left. In general, for
// `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
// to insert, so we need to adjust all its leaves by `i * 2^d`.
let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));
self.leaves.insert(new_leaf_idx, *leaf_value);
}
// add subtree's branch nodes (which includes the root)
// --------------
for (branch_idx, branch_node) in subtree.inner_nodes {
let new_branch_idx = {
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
+ branch_idx.value();
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
};
self.inner_nodes.insert(new_branch_idx, branch_node);
}
// recompute nodes starting from subtree root
// --------------
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
Ok(self.root)
}
}
impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
type Key = LeafIndex<DEPTH>;
type Value = Word;
type Leaf = Word;
type Opening = ValuePath;
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts(
inner_nodes: InnerNodes,
leaves: Leaves,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn root(&self) -> RpoDigest {
self.root
}
fn set_root(&mut self, root: RpoDigest) {
self.root = root;
}
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
if value == Self::EMPTY_VALUE {
self.leaves.remove(&key.value())
} else {
self.leaves.insert(key.value(), value)
}
}
fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
self.get_leaf(key)
}
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
let leaf_pos = key.value();
match self.leaves.get(&leaf_pos) {
Some(word) => *word,
None => Self::EMPTY_VALUE,
}
}
fn hash_leaf(leaf: &Word) -> RpoDigest {
// `SimpleSmt` takes the leaf value itself as the hash
leaf.into()
}
fn construct_prospective_leaf(
&self,
_existing_leaf: Word,
_key: &LeafIndex<DEPTH>,
value: &Word,
) -> Word {
*value
}
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
*key
}
fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
(path, leaf).into()
}
fn pairs_to_leaf(mut pairs: Vec<(LeafIndex<DEPTH>, Word)>) -> Word {
// SimpleSmt can't have more than one value per key.
assert_eq!(pairs.len(), 1);
let (_key, value) = pairs.pop().unwrap();
value
}
}

View file

@ -1,478 +0,0 @@
use alloc::vec::Vec;
use assert_matches::assert_matches;
use super::{
super::{MerkleError, RpoDigest, SimpleSmt},
NodeIndex,
};
use crate::{
hash::rpo::Rpo256,
merkle::{
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
InnerNodeInfo, LeafIndex, MerkleTree,
},
Word, EMPTY_WORD,
};
// TEST DATA
// ================================================================================================
const KEYS4: [u64; 4] = [0, 1, 2, 3];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const VALUES8: [RpoDigest; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
];
const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];
// TESTS
// ================================================================================================
#[test]
fn build_empty_tree() {
// tree of depth 3
let smt = SimpleSmt::<3>::new().unwrap();
let mt = MerkleTree::new(ZERO_VALUES8).unwrap();
assert_eq!(mt.root(), smt.root());
}
#[test]
fn build_sparse_tree() {
const DEPTH: u8 = 3;
let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
let mut values = ZERO_VALUES8.to_vec();
assert_eq!(smt.num_leaves(), 0);
// insert single value
let key = 6;
let new_node = int_to_leaf(7);
values[key as usize] = new_node;
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
let mt2 = MerkleTree::new(values.clone()).unwrap();
assert_eq!(mt2.root(), smt.root());
assert_eq!(
mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
smt.open(&LeafIndex::<3>::new(6).unwrap()).path
);
assert_eq!(old_value, EMPTY_WORD);
assert_eq!(smt.num_leaves(), 1);
// insert second value at distinct leaf branch
let key = 2;
let new_node = int_to_leaf(3);
values[key as usize] = new_node;
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
let mt3 = MerkleTree::new(values).unwrap();
assert_eq!(mt3.root(), smt.root());
assert_eq!(
mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
smt.open(&LeafIndex::<3>::new(2).unwrap()).path
);
assert_eq!(old_value, EMPTY_WORD);
assert_eq!(smt.num_leaves(), 2);
}
/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
#[test]
fn build_contiguous_tree() {
let tree_with_leaves =
SimpleSmt::<2>::with_leaves([0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4)))
.unwrap();
let tree_with_contiguous_leaves =
SimpleSmt::<2>::with_contiguous_leaves(digests_to_words(&VALUES4)).unwrap();
assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
}
#[test]
fn test_depth2_tree() {
let tree =
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
// check internal structure
let (root, node2, node3) = compute_internal_nodes();
assert_eq!(root, tree.root());
assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
// check get_node()
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// check get_path(): depth 2
assert_eq!(vec![VALUES4[1], node3], *tree.open(&LeafIndex::<2>::new(0).unwrap()).path);
assert_eq!(vec![VALUES4[0], node3], *tree.open(&LeafIndex::<2>::new(1).unwrap()).path);
assert_eq!(vec![VALUES4[3], node2], *tree.open(&LeafIndex::<2>::new(2).unwrap()).path);
assert_eq!(vec![VALUES4[2], node2], *tree.open(&LeafIndex::<2>::new(3).unwrap()).path);
}
#[test]
fn test_inner_node_iterator() -> Result<(), MerkleError> {
let tree =
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
// check depth 2
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// get parent nodes
let root = tree.root();
let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let mut nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let mut expected = [
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
];
nodes.sort();
expected.sort();
assert_eq!(nodes, expected);
Ok(())
}
#[test]
fn test_insert() {
const DEPTH: u8 = 3;
let mut tree =
SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();
assert_eq!(tree.num_leaves(), 8);
// update one value
let key = 3;
let new_node = int_to_leaf(9);
let mut expected_values = digests_to_words(&VALUES8);
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 8);
// update another value
let key = 6;
let new_node = int_to_leaf(10);
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 8);
// set a leaf to empty value
let key = 5;
let new_node = EMPTY_WORD;
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
assert_eq!(tree.num_leaves(), 7);
}
#[test]
fn small_tree_opening_is_consistent() {
// ____k____
// / \
// _i_ _j_
// / \ / \
// e f g h
// / \ / \ / \ / \
// a b 0 0 c 0 0 d
let z = EMPTY_WORD;
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
let e = Rpo256::merge(&[a.into(), b.into()]);
let f = Rpo256::merge(&[z.into(), z.into()]);
let g = Rpo256::merge(&[c.into(), z.into()]);
let h = Rpo256::merge(&[z.into(), d.into()]);
let i = Rpo256::merge(&[e, f]);
let j = Rpo256::merge(&[g, h]);
let k = Rpo256::merge(&[i, j]);
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
let tree = SimpleSmt::<3>::with_leaves(entries).unwrap();
assert_eq!(tree.root(), k);
let cases: Vec<(u64, Vec<RpoDigest>)> = vec![
(0, vec![b.into(), f, j]),
(1, vec![a.into(), f, j]),
(4, vec![z.into(), h, i]),
(7, vec![z.into(), g, i]),
];
for (key, path) in cases {
let opening = tree.open(&LeafIndex::<3>::new(key).unwrap());
assert_eq!(path, *opening.path);
}
}
#[test]
fn test_simplesmt_fail_on_duplicates() {
let values = [
// same key, same value
(int_to_leaf(1), int_to_leaf(1)),
// same key, different values
(int_to_leaf(1), int_to_leaf(2)),
// same key, set to zero
(EMPTY_WORD, int_to_leaf(1)),
// same key, re-set to zero
(int_to_leaf(1), EMPTY_WORD),
// same key, set to zero twice
(EMPTY_WORD, EMPTY_WORD),
];
for (first, second) in values.iter() {
// consecutive
let entries = [(1, *first), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
// not consecutive
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
}
}
#[test]
fn with_no_duplicates_empty_node() {
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
let smt = SimpleSmt::<64>::with_leaves(entries);
assert!(smt.is_ok());
}
#[test]
fn test_simplesmt_with_leaves_nonexisting_leaf() {
// TESTING WITH EMPTY WORD
// --------------------------------------------------------------------------------------------
// Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist.
let leaves = [(2, EMPTY_WORD)];
let result = SimpleSmt::<1>::with_leaves(leaves);
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
let leaves = [(4, EMPTY_WORD)];
let result = SimpleSmt::<2>::with_leaves(leaves);
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
let leaves = [(8, EMPTY_WORD)];
let result = SimpleSmt::<3>::with_leaves(leaves);
assert!(result.is_err());
// TESTING WITH A VALUE
// --------------------------------------------------------------------------------------------
let value = int_to_node(1);
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
let leaves = [(2, *value)];
let result = SimpleSmt::<1>::with_leaves(leaves);
assert!(result.is_err());
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
let leaves = [(4, *value)];
let result = SimpleSmt::<2>::with_leaves(leaves);
assert!(result.is_err());
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
let leaves = [(8, *value)];
let result = SimpleSmt::<3>::with_leaves(leaves);
assert!(result.is_err());
}
#[test]
fn test_simplesmt_set_subtree() {
// Final Tree:
//
// ____k____
// / \
// _i_ _j_
// / \ / \
// e f g h
// / \ / \ / \ / \
// a b 0 0 c 0 0 d
let z = EMPTY_WORD;
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
let e = Rpo256::merge(&[a.into(), b.into()]);
let f = Rpo256::merge(&[z.into(), z.into()]);
let g = Rpo256::merge(&[c.into(), z.into()]);
let h = Rpo256::merge(&[z.into(), d.into()]);
let i = Rpo256::merge(&[e, f]);
let j = Rpo256::merge(&[g, h]);
let k = Rpo256::merge(&[i, j]);
// subtree:
// g
// / \
// c 0
let subtree = {
let entries = vec![(0, c)];
SimpleSmt::<1>::with_leaves(entries).unwrap()
};
// insert subtree
const TREE_DEPTH: u8 = 3;
let tree = {
let entries = vec![(0, a), (1, b), (7, d)];
let mut tree = SimpleSmt::<TREE_DEPTH>::with_leaves(entries).unwrap();
tree.set_subtree(2, subtree).unwrap();
tree
};
assert_eq!(tree.root(), k);
assert_eq!(tree.get_leaf(&LeafIndex::<TREE_DEPTH>::new(4).unwrap()), c);
assert_eq!(tree.get_inner_node(NodeIndex::new_unchecked(2, 2)).hash(), g);
}
/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree
#[test]
fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
// Final Tree:
//
// ____k____
// / \
// _i_ _j_
// / \ / \
// e f g h
// / \ / \ / \ / \
// a b 0 0 c 0 0 d
let z = EMPTY_WORD;
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
// subtree:
// g
// / \
// c 0
let subtree = {
let entries = vec![(0, c)];
SimpleSmt::<1>::with_leaves(entries).unwrap()
};
let mut tree = {
let entries = vec![(0, a), (1, b), (7, d)];
SimpleSmt::<3>::with_leaves(entries).unwrap()
};
let tree_root_before_insertion = tree.root();
// insert subtree
assert!(tree.set_subtree(500, subtree).is_err());
assert_eq!(tree.root(), tree_root_before_insertion);
}
/// We insert an empty subtree that has the same depth as the original tree
#[test]
fn test_simplesmt_set_subtree_entire_tree() {
// Initial Tree:
//
// ____k____
// / \
// _i_ _j_
// / \ / \
// e f g h
// / \ / \ / \ / \
// a b 0 0 c 0 0 d
let z = EMPTY_WORD;
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
// subtree: E3
const DEPTH: u8 = 3;
let subtree = { SimpleSmt::<DEPTH>::with_leaves(Vec::new()).unwrap() };
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
// insert subtree
let mut tree = {
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
SimpleSmt::<3>::with_leaves(entries).unwrap()
};
tree.set_subtree(0, subtree).unwrap();
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
}
/// Tests that `EMPTY_ROOT` constant generated in the `SimpleSmt` equals to the root of the empty
/// tree of depth 64
#[test]
fn test_simplesmt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
assert_eq!(empty_root_64_depth, SimpleSmt::<64>::EMPTY_ROOT);
// get the root of the empty tree of depth 32
let empty_root_32_depth = EmptySubtreeRoots::empty_hashes(32)[0];
assert_eq!(empty_root_32_depth, SimpleSmt::<32>::EMPTY_ROOT);
// get the root of the empty tree of depth 0
let empty_root_1_depth = EmptySubtreeRoots::empty_hashes(1)[0];
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
let root = Rpo256::merge(&[node2, node3]);
(root, node2, node3)
}

View file

@ -1,627 +0,0 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::borrow::Borrow;
use super::{
mmr::Mmr, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTree, NodeIndex,
PartialMerkleTree, RootPath, Rpo256, RpoDigest, SimpleSmt, Smt, ValuePath,
};
use crate::utils::{
collections::{KvMap, RecordingMap},
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
#[cfg(test)]
mod tests;
// MERKLE STORE
// ================================================================================================
/// A default [MerkleStore] which uses a simple [BTreeMap] as the backing storage.
pub type DefaultMerkleStore = MerkleStore<BTreeMap<RpoDigest, StoreNode>>;
/// A [MerkleStore] with recording capabilities which uses [RecordingMap] as the backing storage.
pub type RecordingMerkleStore = MerkleStore<RecordingMap<RpoDigest, StoreNode>>;
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct StoreNode {
left: RpoDigest,
right: RpoDigest,
}
/// An in-memory data store for Merkelized data.
///
/// This is a in memory data store for Merkle trees, this store allows all the nodes of multiple
/// trees to live as long as necessary and without duplication, this allows the implementation of
/// space efficient persistent data structures.
///
/// Example usage:
///
/// ```rust
/// # use miden_crypto::{ZERO, Felt, Word};
/// # use miden_crypto::merkle::{NodeIndex, MerkleStore, MerkleTree};
/// # use miden_crypto::hash::rpo::Rpo256;
/// # const fn int_to_node(value: u64) -> Word {
/// # [Felt::new(value), ZERO, ZERO, ZERO]
/// # }
/// # let A = int_to_node(1);
/// # let B = int_to_node(2);
/// # let C = int_to_node(3);
/// # let D = int_to_node(4);
/// # let E = int_to_node(5);
/// # let F = int_to_node(6);
/// # let G = int_to_node(7);
/// # let H0 = int_to_node(8);
/// # let H1 = int_to_node(9);
/// # let T0 = MerkleTree::new([A, B, C, D, E, F, G, H0].to_vec()).expect("even number of leaves provided");
/// # let T1 = MerkleTree::new([A, B, C, D, E, F, G, H1].to_vec()).expect("even number of leaves provided");
/// # let ROOT0 = T0.root();
/// # let ROOT1 = T1.root();
/// let mut store: MerkleStore = MerkleStore::new();
///
/// // the store is initialized with the SMT empty nodes
/// assert_eq!(store.num_internal_nodes(), 255);
///
/// let tree1 = MerkleTree::new(vec![A, B, C, D, E, F, G, H0]).unwrap();
/// let tree2 = MerkleTree::new(vec![A, B, C, D, E, F, G, H1]).unwrap();
///
/// // populates the store with two merkle trees, common nodes are shared
/// store.extend(tree1.inner_nodes());
/// store.extend(tree2.inner_nodes());
///
/// // every leaf except the last are the same
/// for i in 0..7 {
/// let idx0 = NodeIndex::new(3, i).unwrap();
/// let d0 = store.get_node(ROOT0, idx0).unwrap();
/// let idx1 = NodeIndex::new(3, i).unwrap();
/// let d1 = store.get_node(ROOT1, idx1).unwrap();
/// assert_eq!(d0, d1, "Both trees have the same leaf at pos {i}");
/// }
///
/// // The leafs A-B-C-D are the same for both trees, so are their 2 immediate parents
/// for i in 0..4 {
/// let idx0 = NodeIndex::new(3, i).unwrap();
/// let d0 = store.get_path(ROOT0, idx0).unwrap();
/// let idx1 = NodeIndex::new(3, i).unwrap();
/// let d1 = store.get_path(ROOT1, idx1).unwrap();
/// assert_eq!(d0.path[0..2], d1.path[0..2], "Both sub-trees are equal up to two levels");
/// }
///
/// // Common internal nodes are shared, the two added trees have a total of 30, but the store has
/// // only 10 new entries, corresponding to the 10 unique internal nodes of these trees.
/// assert_eq!(store.num_internal_nodes() - 255, 10);
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStore<T: KvMap<RpoDigest, StoreNode> = BTreeMap<RpoDigest, StoreNode>> {
nodes: T,
}
impl<T: KvMap<RpoDigest, StoreNode>> Default for MerkleStore<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Creates an empty `MerkleStore` instance.
pub fn new() -> MerkleStore<T> {
// pre-populate the store with the empty hashes
let nodes = empty_hashes().into_iter().collect();
MerkleStore { nodes }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Return a count of the non-leaf nodes in the store.
pub fn num_internal_nodes(&self) -> usize {
self.nodes.len()
}
/// Returns the node at `index` rooted on the tree `root`.
///
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
/// store.
pub fn get_node(&self, root: RpoDigest, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
let mut hash = root;
// corner case: check the root is in the store when called with index `NodeIndex::root()`
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;
for i in (0..index.depth()).rev() {
let node = self
.nodes
.get(&hash)
.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;
let bit = (index.value() >> i) & 1;
hash = if bit == 0 { node.left } else { node.right }
}
Ok(hash)
}
/// Returns the node at the specified `index` and its opening to the `root`.
///
/// The path starts at the sibling of the target leaf.
///
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
/// store.
pub fn get_path(&self, root: RpoDigest, index: NodeIndex) -> Result<ValuePath, MerkleError> {
let mut hash = root;
let mut path = Vec::with_capacity(index.depth().into());
// corner case: check the root is in the store when called with index `NodeIndex::root()`
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;
for i in (0..index.depth()).rev() {
let node = self
.nodes
.get(&hash)
.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;
let bit = (index.value() >> i) & 1;
hash = if bit == 0 {
path.push(node.right);
node.left
} else {
path.push(node.left);
node.right
}
}
// the path is computed from root to leaf, so it must be reversed
path.reverse();
Ok(ValuePath::new(hash, MerklePath::new(path)))
}
// LEAF TRAVERSAL
// --------------------------------------------------------------------------------------------
/// Returns the depth of the first leaf or an empty node encountered while traversing the tree
/// from the specified root down according to the provided index.
///
/// The `tree_depth` parameter specifies the depth of the tree rooted at `root`. The
/// maximum value the argument accepts is [u64::BITS].
///
/// # Errors
/// Will return an error if:
/// - The provided root is not found.
/// - The provided `tree_depth` is greater than 64.
/// - The provided `index` is not valid for a depth equivalent to `tree_depth`.
/// - No leaf or an empty node was found while traversing the tree down to `tree_depth`.
pub fn get_leaf_depth(
&self,
root: RpoDigest,
tree_depth: u8,
index: u64,
) -> Result<u8, MerkleError> {
// validate depth and index
if tree_depth > 64 {
return Err(MerkleError::DepthTooBig(tree_depth as u64));
}
NodeIndex::new(tree_depth, index)?;
// check if the root exists, providing the proper error report if it doesn't
let empty = EmptySubtreeRoots::empty_hashes(tree_depth);
let mut hash = root;
if !self.nodes.contains_key(&hash) {
return Err(MerkleError::RootNotInStore(hash));
}
// we traverse from root to leaf, so the path is reversed
let mut path = (index << (64 - tree_depth)).reverse_bits();
// iterate every depth and reconstruct the path from root to leaf
for depth in 0..=tree_depth {
// we short-circuit if an empty node has been found
if hash == empty[depth as usize] {
return Ok(depth);
}
// fetch the children pair, mapped by its parent hash
let children = match self.nodes.get(&hash) {
Some(node) => node,
None => return Ok(depth),
};
// traverse down
hash = if path & 1 == 0 { children.left } else { children.right };
path >>= 1;
}
// return an error because we exhausted the index but didn't find either a leaf or an
// empty node
Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
}
/// Returns index and value of a leaf node which is the only leaf node in a subtree defined by
/// the provided root. If the subtree contains zero or more than one leaf nodes None is
/// returned.
///
/// The `tree_depth` parameter specifies the depth of the parent tree such that `root` is
/// located in this tree at `root_index`. The maximum value the argument accepts is
/// [u64::BITS].
///
/// # Errors
/// Will return an error if:
/// - The provided root is not found.
/// - The provided `tree_depth` is greater than 64.
/// - The provided `root_index` has depth greater than `tree_depth`.
/// - A lone node at depth `tree_depth` is not a leaf node.
pub fn find_lone_leaf(
&self,
root: RpoDigest,
root_index: NodeIndex,
tree_depth: u8,
) -> Result<Option<(NodeIndex, RpoDigest)>, MerkleError> {
// we set max depth at u64::BITS as this is the largest meaningful value for a 64-bit index
const MAX_DEPTH: u8 = u64::BITS as u8;
if tree_depth > MAX_DEPTH {
return Err(MerkleError::DepthTooBig(tree_depth as u64));
}
let empty = EmptySubtreeRoots::empty_hashes(MAX_DEPTH);
let mut node = root;
if !self.nodes.contains_key(&node) {
return Err(MerkleError::RootNotInStore(node));
}
let mut index = root_index;
if index.depth() > tree_depth {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}
// traverse down following the path of single non-empty nodes; this works because if a
// node has two empty children it cannot contain a lone leaf. similarly if a node has
// two non-empty children it must contain at least two leaves.
for depth in index.depth()..tree_depth {
// if the node is a leaf, return; otherwise, examine the node's children
let children = match self.nodes.get(&node) {
Some(node) => node,
None => return Ok(Some((index, node))),
};
let empty_node = empty[depth as usize + 1];
node = if children.left != empty_node && children.right == empty_node {
index = index.left_child();
children.left
} else if children.left == empty_node && children.right != empty_node {
index = index.right_child();
children.right
} else {
return Ok(None);
};
}
// if we are here, we got to `tree_depth`; thus, either the current node is a leaf node,
// and so we return it, or it is an internal node, and then we return an error
if self.nodes.contains_key(&node) {
Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
} else {
Ok(Some((index, node)))
}
}
// DATA EXTRACTORS
// --------------------------------------------------------------------------------------------
/// Returns a subset of this Merkle store such that the returned Merkle store contains all
/// nodes which are descendants of the specified roots.
///
/// The roots for which no descendants exist in this Merkle store are ignored.
pub fn subset<I, R>(&self, roots: I) -> MerkleStore<T>
where
I: Iterator<Item = R>,
R: Borrow<RpoDigest>,
{
let mut store = MerkleStore::new();
for root in roots {
let root = *root.borrow();
store.clone_tree_from(root, self);
}
store
}
/// Iterator over the inner nodes of the [MerkleStore].
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.nodes
.iter()
.map(|(r, n)| InnerNodeInfo { value: *r, left: n.left, right: n.right })
}
/// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root`
/// and `max_depth`.
pub fn non_empty_leaves(
&self,
root: RpoDigest,
max_depth: u8,
) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth);
let mut stack = Vec::new();
stack.push((NodeIndex::new_unchecked(0, 0), root));
core::iter::from_fn(move || {
while let Some((index, node_hash)) = stack.pop() {
// if we are at the max depth then we have reached a leaf
if index.depth() == max_depth {
return Some((index, node_hash));
}
// fetch the nodes children and push them onto the stack if they are not the roots
// of empty subtrees
if let Some(node) = self.nodes.get(&node_hash) {
if !empty_roots.contains(&node.left) {
stack.push((index.left_child(), node.left));
}
if !empty_roots.contains(&node.right) {
stack.push((index.right_child(), node.right));
}
// if the node is not in the store assume it is a leaf
} else {
return Some((index, node_hash));
}
}
None
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Adds all the nodes of a Merkle path represented by `path`, opening to `node`. Returns the
/// new root.
///
/// This will compute the sibling elements determined by the Merkle `path` and `node`, and
/// include all the nodes into the store.
pub fn add_merkle_path(
&mut self,
index: u64,
node: RpoDigest,
path: MerklePath,
) -> Result<RpoDigest, MerkleError> {
let root = path.inner_nodes(index, node)?.fold(RpoDigest::default(), |_, node| {
let value: RpoDigest = node.value;
let left: RpoDigest = node.left;
let right: RpoDigest = node.right;
debug_assert_eq!(Rpo256::merge(&[left, right]), value);
self.nodes.insert(value, StoreNode { left, right });
node.value
});
Ok(root)
}
/// Adds all the nodes of multiple Merkle paths into the store.
///
/// This will compute the sibling elements for each Merkle `path` and include all the nodes
/// into the store.
///
/// For further reference, check [MerkleStore::add_merkle_path].
pub fn add_merkle_paths<I>(&mut self, paths: I) -> Result<(), MerkleError>
where
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
{
for (index_value, node, path) in paths.into_iter() {
self.add_merkle_path(index_value, node, path)?;
}
Ok(())
}
/// Sets a node to `value`.
///
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in the
/// store.
pub fn set_node(
&mut self,
mut root: RpoDigest,
index: NodeIndex,
value: RpoDigest,
) -> Result<RootPath, MerkleError> {
let node = value;
let ValuePath { value, path } = self.get_path(root, index)?;
// performs the update only if the node value differs from the opening
if node != value {
root = self.add_merkle_path(index.value(), node, path.clone())?;
}
Ok(RootPath { root, path })
}
/// Merges two elements and adds the resulting node into the store.
///
/// Merges arbitrary values. They may be leafs, nodes, or a mixture of both.
pub fn merge_roots(
&mut self,
left_root: RpoDigest,
right_root: RpoDigest,
) -> Result<RpoDigest, MerkleError> {
let parent = Rpo256::merge(&[left_root, right_root]);
self.nodes.insert(parent, StoreNode { left: left_root, right: right_root });
Ok(parent)
}
// DESTRUCTURING
// --------------------------------------------------------------------------------------------
/// Returns the inner storage of this MerkleStore while consuming `self`.
pub fn into_inner(self) -> T {
self.nodes
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Recursively clones a tree with the specified root from the specified source into self.
///
/// If the source store does not contain a tree with the specified root, this is a noop.
fn clone_tree_from(&mut self, root: RpoDigest, source: &Self) {
// process the node only if it is in the source
if let Some(node) = source.nodes.get(&root) {
// if the node has already been inserted, no need to process it further as all of its
// descendants should be already cloned from the source store
if self.nodes.insert(root, *node).is_none() {
self.clone_tree_from(node.left, source);
self.clone_tree_from(node.right, source);
}
}
}
}
// CONVERSIONS
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
fn from(value: &MerkleTree) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>, const DEPTH: u8> From<&SimpleSmt<DEPTH>> for MerkleStore<T> {
fn from(value: &SimpleSmt<DEPTH>) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&Smt> for MerkleStore<T> {
fn from(value: &Smt) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
fn from(value: &Mmr) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
fn from(value: &PartialMerkleTree) -> Self {
let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> From<T> for MerkleStore<T> {
fn from(values: T) -> Self {
let nodes = values.into_iter().chain(empty_hashes()).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<InnerNodeInfo> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = InnerNodeInfo>>(iter: I) -> Self {
let nodes = combine_nodes_with_empty_hashes(iter).collect();
Self { nodes }
}
}
impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for MerkleStore<T> {
fn from_iter<I: IntoIterator<Item = (RpoDigest, StoreNode)>>(iter: I) -> Self {
let nodes = iter.into_iter().chain(empty_hashes()).collect();
Self { nodes }
}
}
// ITERATORS
// ================================================================================================
impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
fn extend<I: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: I) {
self.nodes.extend(
iter.into_iter()
.map(|info| (info.value, StoreNode { left: info.left, right: info.right })),
);
}
}
// SERIALIZATION
// ================================================================================================
impl Serializable for StoreNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
}
}
impl Deserializable for StoreNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = RpoDigest::read_from(source)?;
let right = RpoDigest::read_from(source)?;
Ok(StoreNode { left, right })
}
}
impl<T: KvMap<RpoDigest, StoreNode>> Serializable for MerkleStore<T> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u64(self.nodes.len() as u64);
for (k, v) in self.nodes.iter() {
k.write_into(target);
v.write_into(target);
}
}
}
impl<T: KvMap<RpoDigest, StoreNode>> Deserializable for MerkleStore<T> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let len = source.read_u64()?;
let mut nodes: Vec<(RpoDigest, StoreNode)> = Vec::with_capacity(len as usize);
for _ in 0..len {
let key = RpoDigest::read_from(source)?;
let value = StoreNode::read_from(source)?;
nodes.push((key, value));
}
Ok(nodes.into_iter().collect())
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Creates empty hashes for all the subtrees of a tree with a max depth of 255.
fn empty_hashes() -> impl IntoIterator<Item = (RpoDigest, StoreNode)> {
let subtrees = EmptySubtreeRoots::empty_hashes(255);
subtrees
.iter()
.rev()
.copied()
.zip(subtrees.iter().rev().skip(1).copied())
.map(|(child, parent)| (parent, StoreNode { left: child, right: child }))
}
/// Consumes an iterator of [InnerNodeInfo] and returns an iterator of `(value, node)` tuples
/// which includes the nodes associate with roots of empty subtrees up to a depth of 255.
fn combine_nodes_with_empty_hashes(
nodes: impl IntoIterator<Item = InnerNodeInfo>,
) -> impl Iterator<Item = (RpoDigest, StoreNode)> {
nodes
.into_iter()
.map(|info| (info.value, StoreNode { left: info.left, right: info.right }))
.chain(empty_hashes())
}

View file

@ -1,933 +0,0 @@
use assert_matches::assert_matches;
use seq_macro::seq;
#[cfg(feature = "std")]
use {
super::{Deserializable, Serializable},
alloc::boxed::Box,
std::error::Error,
};
use super::{
DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
};
use crate::{
merkle::{
digests_to_words, int_to_leaf, int_to_node, LeafIndex, MerkleTree, SimpleSmt, SMT_MAX_DEPTH,
},
Felt, Word, ONE, WORD_SIZE, ZERO,
};
// TEST DATA
// ================================================================================================
const KEYS4: [u64; 4] = [0, 1, 2, 3];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
const VALUES8: [RpoDigest; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
];
// TESTS
// ================================================================================================
#[test]
fn test_root_not_in_store() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
assert_matches!(
store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
assert_matches!(
store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
Ok(())
}
#[test]
fn test_merkle_tree() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
// STORE LEAVES ARE CORRECT -------------------------------------------------------------------
// checks the leaves in the store corresponds to the expected values
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
VALUES4[0],
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
VALUES4[1],
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
VALUES4[2],
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
VALUES4[3],
"node 3 must be in the tree"
);
// STORE LEAVES MATCH TREE --------------------------------------------------------------------
// sanity check the values returned by the store and the tree
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 0)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
"node 0 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 1)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
"node 1 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 2)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
"node 2 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 3)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
"node 3 must be the same for both MerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
assert_eq!(
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 0)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap();
assert_eq!(
VALUES4[1], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 1)).unwrap(),
result.path,
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap();
assert_eq!(
VALUES4[2], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 2)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap();
assert_eq!(
VALUES4[3], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 3)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
Ok(())
}
#[test]
fn test_empty_roots() {
let store = MerkleStore::default();
let mut root = RpoDigest::default();
for depth in 0..255 {
root = Rpo256::merge(&[root; 2]);
assert!(
store.get_node(root, NodeIndex::make(0, 0)).is_ok(),
"The root of the empty tree of depth {depth} must be registered"
);
}
}
#[test]
fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
let store = MerkleStore::default();
// Starts at 1 because leafs are not included in the store.
// Ends at 64 because it is not possible to represent an index of a depth greater than 64,
// because a u64 is used to index the leaf.
seq!(DEPTH in 1_u8..64_u8 {
let smt = SimpleSmt::<DEPTH>::new()?;
let index = NodeIndex::make(DEPTH, 0);
let store_path = store.get_path(smt.root(), index)?;
let smt_path = smt.open(&LeafIndex::<DEPTH>::new(0)?).path;
assert_eq!(
store_path.value,
RpoDigest::default(),
"the leaf of an empty tree is always ZERO"
);
assert_eq!(
store_path.path, smt_path,
"the returned merkle path does not match the computed values"
);
assert_eq!(
store_path.path.compute_root(DEPTH.into(), RpoDigest::default()).unwrap(),
smt.root(),
"computed root from the path must match the empty tree root"
);
});
Ok(())
}
#[test]
fn test_get_invalid_node() {
let mtree =
MerkleTree::new(digests_to_words(&VALUES4)).expect("creating a merkle tree must work");
let store = MerkleStore::from(&mtree);
let _ = store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3));
}
#[test]
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
let keys2: [u64; 2] = [0, 1];
let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
let smt = SimpleSmt::<1>::with_leaves(keys2.into_iter().zip(leaves2)).unwrap();
let store = MerkleStore::from(&smt);
let idx = NodeIndex::make(1, 0);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[0].into());
assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());
let idx = NodeIndex::make(1, 1);
assert_eq!(smt.get_node(idx).unwrap(), leaves2[1].into());
assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());
Ok(())
}
#[test]
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
let smt =
SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4)))
.unwrap();
let store = MerkleStore::from(&smt);
// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store corresponds to the expected values
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
VALUES4[0],
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
VALUES4[1],
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
VALUES4[2],
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
VALUES4[3],
"node 3 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
RpoDigest::default(),
"unmodified node 4 must be ZERO"
);
// STORE LEAVES MATCH TREE ===============================================================
// sanity check the values returned by the store and the tree
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
"node 0 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
"node 1 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
"node 2 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
"node 3 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
"node 4 must be the same for both SparseMerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// assert the merkle path returned by the store is the same as the one in the tree
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap();
assert_eq!(
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(0).unwrap()).path,
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap();
assert_eq!(
VALUES4[1], result.value,
"Value for merkle path at index 1 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(1).unwrap()).path,
result.path,
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap();
assert_eq!(
VALUES4[2], result.value,
"Value for merkle path at index 2 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(2).unwrap()).path,
result.path,
"merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap();
assert_eq!(
VALUES4[3], result.value,
"Value for merkle path at index 3 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(3).unwrap()).path,
result.path,
"merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap();
assert_eq!(
RpoDigest::default(),
result.value,
"Value for merkle path at index 4 must match leaf value"
);
assert_eq!(
smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(4).unwrap()).path,
result.path,
"merkle path for index 4 must be the same for the MerkleTree and MerkleStore"
);
Ok(())
}
#[test]
fn test_add_merkle_paths() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let i0 = 0;
let p0 = mtree.get_path(NodeIndex::make(2, i0)).unwrap();
let i1 = 1;
let p1 = mtree.get_path(NodeIndex::make(2, i1)).unwrap();
let i2 = 2;
let p2 = mtree.get_path(NodeIndex::make(2, i2)).unwrap();
let i3 = 3;
let p3 = mtree.get_path(NodeIndex::make(2, i3)).unwrap();
let paths = [
(i0, VALUES4[i0 as usize], p0),
(i1, VALUES4[i1 as usize], p1),
(i2, VALUES4[i2 as usize], p2),
(i3, VALUES4[i3 as usize], p3),
];
let mut store = MerkleStore::default();
store.add_merkle_paths(paths.clone()).expect("the valid paths must work");
let pmt = PartialMerkleTree::with_paths(paths).unwrap();
// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store corresponds to the expected values
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
VALUES4[0],
"node 0 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
VALUES4[1],
"node 1 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
VALUES4[2],
"node 2 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
VALUES4[3],
"node 3 must be in the pmt"
);
// STORE LEAVES MATCH PMT ================================================================
// sanity check the values returned by the store and the pmt
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
"node 0 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
"node 1 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
"node 2 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
"node 3 must be the same for both PartialMerkleTree and MerkleStore"
);
// STORE MERKLE PATH MATCHES ==============================================================
// assert the merkle path returned by the store is the same as the one in the pmt
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
assert_eq!(
VALUES4[0], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap();
assert_eq!(
VALUES4[1], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
result.path,
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap();
assert_eq!(
VALUES4[2], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap();
assert_eq!(
VALUES4[3], result.value,
"Value for merkle path at index 0 must match leaf value"
);
assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
);
Ok(())
}
#[test]
fn wont_open_to_different_depth_root() {
let empty = EmptySubtreeRoots::empty_hashes(64);
let a = [ONE; 4];
let b = [Felt::new(2); 4];
// Compute the root for a different depth. We cherry-pick this specific depth to prevent a
// regression to a bug in the past that allowed the user to fetch a node at a depth lower than
// the inserted path of a Merkle tree.
let mut root = Rpo256::merge(&[a.into(), b.into()]);
for depth in (1..=63).rev() {
root = Rpo256::merge(&[root, empty[depth]]);
}
// For this example, the depth of the Merkle tree is 1, as we have only two leaves. Here we
// attempt to fetch a node on the maximum depth, and it should fail because the root shouldn't
// exist for the set.
let mtree = MerkleTree::new(vec![a, b]).unwrap();
let store = MerkleStore::from(&mtree);
let index = NodeIndex::root();
let err = store.get_node(root, index).err().unwrap();
assert_matches!(err, MerkleError::RootNotInStore(err_root) if err_root == root);
}
#[test]
fn store_path_opens_from_leaf() {
let a = [ONE; 4];
let b = [Felt::new(2); 4];
let c = [Felt::new(3); 4];
let d = [Felt::new(4); 4];
let e = [Felt::new(5); 4];
let f = [Felt::new(6); 4];
let g = [Felt::new(7); 4];
let h = [Felt::new(8); 4];
let i = Rpo256::merge(&[a.into(), b.into()]);
let j = Rpo256::merge(&[c.into(), d.into()]);
let k = Rpo256::merge(&[e.into(), f.into()]);
let l = Rpo256::merge(&[g.into(), h.into()]);
let m = Rpo256::merge(&[i, j]);
let n = Rpo256::merge(&[k, l]);
let root = Rpo256::merge(&[m, n]);
let mtree = MerkleTree::new(vec![a, b, c, d, e, f, g, h]).unwrap();
let store = MerkleStore::from(&mtree);
let path = store.get_path(root, NodeIndex::make(3, 1)).unwrap().path;
let expected = MerklePath::new([a.into(), j, n].to_vec());
assert_eq!(path, expected);
}
#[test]
fn test_set_node() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let mut store = MerkleStore::from(&mtree);
let value = int_to_node(42);
let index = NodeIndex::make(mtree.depth(), 0);
let new_root = store.set_node(mtree.root(), index, value)?.root;
assert_eq!(store.get_node(new_root, index).unwrap(), value, "value must have changed");
Ok(())
}
#[test]
fn test_constructors() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
let depth = mtree.depth();
let leaves = 2u64.pow(depth.into());
for index in 0..leaves {
let index = NodeIndex::make(depth, index);
let value_path = store.get_path(mtree.root(), index)?;
assert_eq!(mtree.get_path(index)?, value_path.path);
}
const DEPTH: u8 = 32;
let smt =
SimpleSmt::<DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
let store = MerkleStore::from(&smt);
for key in KEYS4 {
let index = NodeIndex::make(DEPTH, key);
let value_path = store.get_path(smt.root(), index)?;
assert_eq!(smt.open(&LeafIndex::<DEPTH>::new(key).unwrap()).path, value_path.path);
}
let d = 2;
let paths = [
(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()),
(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()),
(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()),
(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()),
];
let mut store1 = MerkleStore::default();
store1.add_merkle_paths(paths.clone())?;
let mut store2 = MerkleStore::default();
store2.add_merkle_path(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?;
store2.add_merkle_path(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?;
store2.add_merkle_path(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?;
store2.add_merkle_path(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?;
let pmt = PartialMerkleTree::with_paths(paths).unwrap();
for key in [0, 1, 2, 3] {
let index = NodeIndex::make(d, key);
let value_path1 = store1.get_path(pmt.root(), index)?;
let value_path2 = store2.get_path(pmt.root(), index)?;
assert_eq!(value_path1, value_path2);
let index = NodeIndex::make(d, key);
assert_eq!(pmt.get_path(index)?, value_path1.path);
}
Ok(())
}
#[test]
fn node_path_should_be_truncated_by_midtier_insert() {
let key = 0b11010010_11001100_11001100_11001100_11001100_11001100_11001100_11001100_u64;
let mut store = MerkleStore::new();
let root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];
// insert first node - works as expected
let depth = 64;
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
let index = NodeIndex::new(depth, key).unwrap();
let root = store.set_node(root, index, node).unwrap().root;
let result = store.get_node(root, index).unwrap();
let path = store.get_path(root, index).unwrap().path;
assert_eq!(node, result);
assert_eq!(path.depth(), depth);
assert!(path.verify(index.value(), result, &root).is_ok());
// flip the first bit of the key and insert the second node on a different depth
let key = key ^ (1 << 63);
let key = key >> 8;
let depth = 56;
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
let index = NodeIndex::new(depth, key).unwrap();
let root = store.set_node(root, index, node).unwrap().root;
let result = store.get_node(root, index).unwrap();
let path = store.get_path(root, index).unwrap().path;
assert_eq!(node, result);
assert_eq!(path.depth(), depth);
assert!(path.verify(index.value(), result, &root).is_ok());
// attempt to fetch a path of the second node to depth 64
// should fail because the previously inserted node will remove its sub-tree from the set
let key = key << 8;
let index = NodeIndex::new(64, key).unwrap();
assert!(store.get_node(root, index).is_err());
}
// LEAF TRAVERSAL
// ================================================================================================
#[test]
fn get_leaf_depth_works_depth_64() {
let mut store = MerkleStore::new();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];
let key = u64::MAX;
// this will create a rainbow tree and test all opening to depth 64
for d in 0..64 {
let k = key & (u64::MAX >> d);
let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
let index = NodeIndex::new(64, k).unwrap();
// assert the leaf doesn't exist before the insert. the returned depth should always
// increment with the paths count of the set, as they are intersecting one another up to
// the first bits of the used key.
assert_eq!(d, store.get_leaf_depth(root, 64, k).unwrap());
// insert and assert the correct depth
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(64, store.get_leaf_depth(root, 64, k).unwrap());
}
}
#[test]
fn get_leaf_depth_works_with_incremental_depth() {
let mut store = MerkleStore::new();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];
// insert some path to the left of the root and assert it
let key = 0b01001011_10110110_00001101_01110100_00111011_10101101_00000100_01000001_u64;
assert_eq!(0, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 64;
let index = NodeIndex::new(depth, key).unwrap();
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
// flip the key to the right of the root and insert some content on depth 16
let key = 0b11001011_10110110_00000000_00000000_00000000_00000000_00000000_00000000_u64;
assert_eq!(1, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 16;
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
// attempt the sibling of the previous leaf
let key = 0b11001011_10110111_00000000_00000000_00000000_00000000_00000000_00000000_u64;
assert_eq!(16, store.get_leaf_depth(root, 64, key).unwrap());
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
// move down to the next depth and assert correct behavior
let key = 0b11001011_10110100_00000000_00000000_00000000_00000000_00000000_00000000_u64;
assert_eq!(15, store.get_leaf_depth(root, 64, key).unwrap());
let depth = 17;
let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
}
#[test]
fn get_leaf_depth_works_with_depth_8() {
let mut store = MerkleStore::new();
let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(8)[0];
// insert some random, 8 depth keys. `a` diverges from the first bit
let a = 0b01101001_u64;
let b = 0b10011001_u64;
let c = 0b10010110_u64;
let d = 0b11110110_u64;
for k in [a, b, c, d] {
let index = NodeIndex::new(8, k).unwrap();
let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
root = store.set_node(root, index, node).unwrap().root;
}
// assert all leaves returns the inserted depth
for k in [a, b, c, d] {
assert_eq!(8, store.get_leaf_depth(root, 8, k).unwrap());
}
// flip last bit of a and expect it to return the same depth, but for an empty node
assert_eq!(8, store.get_leaf_depth(root, 8, 0b01101000_u64).unwrap());
// flip fourth bit of a and expect an empty node on depth 4
assert_eq!(4, store.get_leaf_depth(root, 8, 0b01111001_u64).unwrap());
// flip third bit of a and expect an empty node on depth 3
assert_eq!(3, store.get_leaf_depth(root, 8, 0b01001001_u64).unwrap());
// flip second bit of a and expect an empty node on depth 2
assert_eq!(2, store.get_leaf_depth(root, 8, 0b00101001_u64).unwrap());
// flip fourth bit of c and expect an empty node on depth 4
assert_eq!(4, store.get_leaf_depth(root, 8, 0b10000110_u64).unwrap());
// flip second bit of d and expect an empty node on depth 3 as depth 2 conflicts with b and c
assert_eq!(3, store.get_leaf_depth(root, 8, 0b10110110_u64).unwrap());
// duplicate the tree on `a` and assert the depth is short-circuited by such sub-tree
let index = NodeIndex::new(8, a).unwrap();
root = store.set_node(root, index, root).unwrap().root;
assert_matches!(store.get_leaf_depth(root, 8, a).unwrap_err(), MerkleError::DepthTooBig(9));
}
#[test]
fn find_lone_leaf() {
let mut store = MerkleStore::new();
let empty = EmptySubtreeRoots::empty_hashes(64);
let mut root: RpoDigest = empty[0];
// insert a single leaf into the store at depth 64
let key_a = 0b01010101_10101010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
let idx_a = NodeIndex::make(64, key_a);
let val_a = RpoDigest::from([ONE, ONE, ONE, ONE]);
root = store.set_node(root, idx_a, val_a).unwrap().root;
// for every ancestor of A, A should be a long leaf
for depth in 1..64 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_a, val_a)));
}
// insert another leaf into the store such that it has the same 8 bit prefix as A
let key_b = 0b01010101_01111010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
let idx_b = NodeIndex::make(64, key_b);
let val_b = RpoDigest::from([ONE, ONE, ONE, ZERO]);
root = store.set_node(root, idx_b, val_b).unwrap().root;
// for any node which is common between A and B, find_lone_leaf() should return None as the
// node has two descendants
for depth in 1..9 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, None);
}
// for other ancestors of A and B, A and B should be lone leaves respectively
for depth in 9..64 {
let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_a, val_a)));
}
for depth in 9..64 {
let parent_index = NodeIndex::make(depth, key_b >> (64 - depth));
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, Some((idx_b, val_b)));
}
// for any other node, find_lone_leaf() should return None as they have no leaf nodes
let parent_index = NodeIndex::make(16, 0b01010101_11111111);
let parent = store.get_node(root, parent_index).unwrap();
let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
assert_eq!(res, None);
}
// SUBSET EXTRACTION
// ================================================================================================
#[test]
fn mstore_subset() {
// add a Merkle tree of depth 3 to the store
let mtree = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
let mut store = MerkleStore::default();
let empty_store_num_nodes = store.nodes.len();
store.extend(mtree.inner_nodes());
// build 3 subtrees contained within the above Merkle tree; note that subtree2 is a subset
// of subtree1
let subtree1 = MerkleTree::new(digests_to_words(&VALUES8[..4])).unwrap();
let subtree2 = MerkleTree::new(digests_to_words(&VALUES8[2..4])).unwrap();
let subtree3 = MerkleTree::new(digests_to_words(&VALUES8[6..])).unwrap();
// --- extract all 3 subtrees ---------------------------------------------
let substore = store.subset([subtree1.root(), subtree2.root(), subtree3.root()].iter());
// number of nodes should increase by 4: 3 nodes form subtree1 and 1 node from subtree3
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);
// make sure paths that all subtrees are in the store
check_mstore_subtree(&substore, &subtree1);
check_mstore_subtree(&substore, &subtree2);
check_mstore_subtree(&substore, &subtree3);
// --- extract subtrees 1 and 3 -------------------------------------------
// this should give the same result as above as subtree2 is nested within subtree1
let substore = store.subset([subtree1.root(), subtree3.root()].iter());
// number of nodes should increase by 4: 3 nodes form subtree1 and 1 node from subtree3
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);
// make sure paths that all subtrees are in the store
check_mstore_subtree(&substore, &subtree1);
check_mstore_subtree(&substore, &subtree2);
check_mstore_subtree(&substore, &subtree3);
}
fn check_mstore_subtree(store: &MerkleStore, subtree: &MerkleTree) {
for (i, value) in subtree.leaves() {
let index = NodeIndex::new(subtree.depth(), i).unwrap();
let path1 = store.get_path(subtree.root(), index).unwrap();
assert_eq!(*path1.value, *value);
let path2 = subtree.get_path(index).unwrap();
assert_eq!(path1.path, path2);
}
}
// SERIALIZATION
// ================================================================================================
#[cfg(feature = "std")]
#[test]
fn test_serialization() -> Result<(), Box<dyn Error>> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
let decoded = MerkleStore::read_from_bytes(&store.to_bytes()).expect("deserialization failed");
assert_eq!(store, decoded);
Ok(())
}
// MERKLE RECORDER
// ================================================================================================
#[test]
fn test_recorder() {
// instantiate recorder from MerkleTree and SimpleSmt
let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
const TREE_DEPTH: u8 = 64;
let smtree = SimpleSmt::<TREE_DEPTH>::with_leaves(
KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
)
.unwrap();
let mut recorder: RecordingMerkleStore =
mtree.inner_nodes().chain(smtree.inner_nodes()).collect();
// get nodes from both trees and make sure they are correct
let index_0 = NodeIndex::new(mtree.depth(), 0).unwrap();
let node = recorder.get_node(mtree.root(), index_0).unwrap();
assert_eq!(node, mtree.get_node(index_0).unwrap());
let index_1 = NodeIndex::new(TREE_DEPTH, 1).unwrap();
let node = recorder.get_node(smtree.root(), index_1).unwrap();
assert_eq!(node, smtree.get_node(index_1).unwrap());
// insert a value and assert that when we request it next time it is accurate
let new_value = [ZERO, ZERO, ONE, ONE].into();
let index_2 = NodeIndex::new(TREE_DEPTH, 2).unwrap();
let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);
// construct the proof
let rec_map = recorder.into_inner();
let (_, proof) = rec_map.finalize();
let merkle_store: MerkleStore = proof.into();
// make sure the proof contains all nodes from both trees
let node = merkle_store.get_node(mtree.root(), index_0).unwrap();
assert_eq!(node, mtree.get_node(index_0).unwrap());
let node = merkle_store.get_node(smtree.root(), index_1).unwrap();
assert_eq!(node, smtree.get_node(index_1).unwrap());
let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
assert_eq!(
node,
smtree.get_leaf(&LeafIndex::<TREE_DEPTH>::try_from(index_2).unwrap()).into()
);
// assert that is doesnt contain nodes that were not recorded
let not_recorded_index = NodeIndex::new(TREE_DEPTH, 4).unwrap();
assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
assert!(smtree.get_node(not_recorded_index).is_ok());
}

View file

@ -1,23 +0,0 @@
//! Pseudo-random element generation.
use rand::RngCore;
pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
pub use winter_utils::Randomizable;
use crate::{Felt, FieldElement, Word, ZERO};
mod rpo;
mod rpx;
pub use rpo::RpoRandomCoin;
pub use rpx::RpxRandomCoin;
/// Pseudo-random element generator.
///
/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
pub trait FeltRng: RngCore {
/// Draw, uniformly at random, a base field element.
fn draw_element(&mut self) -> Felt;
/// Draw, uniformly at random, a [Word].
fn draw_word(&mut self) -> Word;
}

View file

@ -1,296 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use rand_core::impls;
use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
// CONSTANTS
// ================================================================================================
const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
const RATE_START: usize = Rpo256::RATE_RANGE.start;
const RATE_END: usize = Rpo256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;
// RPO RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
/// is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
/// material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpoRandomCoin {
state: [Felt; STATE_WIDTH],
current: usize,
}
impl RpoRandomCoin {
/// Returns a new [RpoRandomCoin] initialize with the specified seed.
pub fn new(seed: Word) -> Self {
let mut state = [ZERO; STATE_WIDTH];
for i in 0..HALF_RATE_WIDTH {
state[RATE_START + i] += seed[i];
}
// Absorb
Rpo256::apply_permutation(&mut state);
RpoRandomCoin { state, current: RATE_START }
}
/// Returns an [RpoRandomCoin] instantiated from the provided components.
///
/// # Panics
/// Panics if `current` is smaller than 4 or greater than or equal to 12.
pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
assert!(
(RATE_START..RATE_END).contains(&current),
"current value outside of valid range"
);
Self { state, current }
}
/// Returns components of this random coin.
pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
(self.state, self.current)
}
/// Fills `dest` with random data.
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
<Self as RngCore>::fill_bytes(self, dest)
}
fn draw_basefield(&mut self) -> Felt {
if self.current == RATE_END {
Rpo256::apply_permutation(&mut self.state);
self.current = RATE_START;
}
self.current += 1;
self.state[self.current - 1]
}
}
// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RandomCoin for RpoRandomCoin {
type BaseField = Felt;
type Hasher = Rpo256;
fn new(seed: &[Self::BaseField]) -> Self {
let digest: Word = Rpo256::hash_elements(seed).into();
Self::new(digest)
}
fn reseed(&mut self, data: RpoDigest) {
// Reset buffer
self.current = RATE_START;
// Add the new seed material to the first half of the rate portion of the RPO state
let data: Word = data.into();
self.state[RATE_START] += data[0];
self.state[RATE_START + 1] += data[1];
self.state[RATE_START + 2] += data[2];
self.state[RATE_START + 3] += data[3];
// Absorb
Rpo256::apply_permutation(&mut self.state);
}
fn check_leading_zeros(&self, value: u64) -> u32 {
let value = Felt::new(value);
let mut state_tmp = self.state;
state_tmp[RATE_START] += value;
Rpo256::apply_permutation(&mut state_tmp);
let first_rate_element = state_tmp[RATE_START].as_int();
first_rate_element.trailing_zeros()
}
fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
let ext_degree = E::EXTENSION_DEGREE;
let mut result = vec![ZERO; ext_degree];
for r in result.iter_mut().take(ext_degree) {
*r = self.draw_basefield();
}
let result = E::slice_from_base_elements(&result);
Ok(result[0])
}
fn draw_integers(
&mut self,
num_values: usize,
domain_size: usize,
nonce: u64,
) -> Result<Vec<usize>, RandomCoinError> {
assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
assert!(num_values < domain_size, "number of values must be smaller than domain size");
// absorb the nonce
let nonce = Felt::new(nonce);
self.state[RATE_START] += nonce;
Rpo256::apply_permutation(&mut self.state);
// reset the buffer and move the next random element pointer to the second rate element.
// this is done as the first rate element will be "biased" via the provided `nonce` to
// contain some number of leading zeros.
self.current = RATE_START + 1;
// determine how many bits are needed to represent valid values in the domain
let v_mask = (domain_size - 1) as u64;
// draw values from PRNG until we get as many unique values as specified by num_queries
let mut values = Vec::new();
for _ in 0..1000 {
// get the next pseudo-random field element
let value = self.draw_basefield().as_int();
// use the mask to get a value within the range
let value = (value & v_mask) as usize;
values.push(value);
if values.len() == num_values {
break;
}
}
if values.len() < num_values {
return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
}
Ok(values)
}
}
// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl FeltRng for RpoRandomCoin {
fn draw_element(&mut self) -> Felt {
self.draw_basefield()
}
fn draw_word(&mut self) -> Word {
let mut output = [ZERO; 4];
for o in output.iter_mut() {
*o = self.draw_basefield();
}
output
}
}
// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RngCore for RpoRandomCoin {
fn next_u32(&mut self) -> u32 {
self.draw_basefield().as_int() as u32
}
fn next_u64(&mut self) -> u64 {
impls::next_u64_via_u32(self)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_next(self, dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
self.fill_bytes(dest);
Ok(())
}
}
// SERIALIZATION
// ------------------------------------------------------------------------------------------------
impl Serializable for RpoRandomCoin {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.state.iter().for_each(|v| v.write_into(target));
// casting to u8 is OK because `current` is always between 4 and 12.
target.write_u8(self.current as u8);
}
}
impl Deserializable for RpoRandomCoin {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let state = [
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
];
let current = source.read_u8()? as usize;
if !(RATE_START..RATE_END).contains(&current) {
return Err(DeserializationError::InvalidValue(
"current value outside of valid range".to_string(),
));
}
Ok(Self { state, current })
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
use crate::ONE;
#[test]
fn test_feltrng_felt() {
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
let output = rpocoin.draw_element();
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
let expected = rpocoin.draw_basefield();
assert_eq!(output, expected);
}
#[test]
fn test_feltrng_word() {
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
let output = rpocoin.draw_word();
let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
let mut expected = [ZERO; 4];
for o in expected.iter_mut() {
*o = rpocoin.draw_basefield();
}
assert_eq!(output, expected);
}
#[test]
fn test_feltrng_serialization() {
let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);
let bytes = coin1.to_bytes();
let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
assert_eq!(coin1, coin2);
}
}

View file

@ -1,294 +0,0 @@
use alloc::{string::ToString, vec::Vec};
use rand_core::impls;
use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
hash::rpx::{Rpx256, RpxDigest},
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
// CONSTANTS
// ================================================================================================
const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
const RATE_START: usize = Rpx256::RATE_RANGE.start;
const RATE_END: usize = Rpx256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;
// RPX RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
/// is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
/// material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpxRandomCoin {
state: [Felt; STATE_WIDTH],
current: usize,
}
impl RpxRandomCoin {
/// Returns a new [RpxRandomCoin] initialize with the specified seed.
pub fn new(seed: Word) -> Self {
let mut state = [ZERO; STATE_WIDTH];
for i in 0..HALF_RATE_WIDTH {
state[RATE_START + i] += seed[i];
}
// Absorb
Rpx256::apply_permutation(&mut state);
RpxRandomCoin { state, current: RATE_START }
}
/// Returns an [RpxRandomCoin] instantiated from the provided components.
///
/// # Panics
/// Panics if `current` is smaller than 4 or greater than or equal to 12.
pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
assert!(
(RATE_START..RATE_END).contains(&current),
"current value outside of valid range"
);
Self { state, current }
}
/// Returns components of this random coin.
pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
(self.state, self.current)
}
/// Fills `dest` with random data.
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
<Self as RngCore>::fill_bytes(self, dest)
}
fn draw_basefield(&mut self) -> Felt {
if self.current == RATE_END {
Rpx256::apply_permutation(&mut self.state);
self.current = RATE_START;
}
self.current += 1;
self.state[self.current - 1]
}
}
// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RandomCoin for RpxRandomCoin {
type BaseField = Felt;
type Hasher = Rpx256;
fn new(seed: &[Self::BaseField]) -> Self {
let digest: Word = Rpx256::hash_elements(seed).into();
Self::new(digest)
}
fn reseed(&mut self, data: RpxDigest) {
// Reset buffer
self.current = RATE_START;
// Add the new seed material to the first half of the rate portion of the RPX state
let data: Word = data.into();
self.state[RATE_START] += data[0];
self.state[RATE_START + 1] += data[1];
self.state[RATE_START + 2] += data[2];
self.state[RATE_START + 3] += data[3];
// Absorb
Rpx256::apply_permutation(&mut self.state);
}
fn check_leading_zeros(&self, value: u64) -> u32 {
let value = Felt::new(value);
let mut state_tmp = self.state;
state_tmp[RATE_START] += value;
Rpx256::apply_permutation(&mut state_tmp);
let first_rate_element = state_tmp[RATE_START].as_int();
first_rate_element.trailing_zeros()
}
fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
let ext_degree = E::EXTENSION_DEGREE;
let mut result = vec![ZERO; ext_degree];
for r in result.iter_mut().take(ext_degree) {
*r = self.draw_basefield();
}
let result = E::slice_from_base_elements(&result);
Ok(result[0])
}
fn draw_integers(
&mut self,
num_values: usize,
domain_size: usize,
nonce: u64,
) -> Result<Vec<usize>, RandomCoinError> {
assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
assert!(num_values < domain_size, "number of values must be smaller than domain size");
// absorb the nonce
let nonce = Felt::new(nonce);
self.state[RATE_START] += nonce;
Rpx256::apply_permutation(&mut self.state);
// reset the buffer
self.current = RATE_START;
// determine how many bits are needed to represent valid values in the domain
let v_mask = (domain_size - 1) as u64;
// draw values from PRNG until we get as many unique values as specified by num_queries
let mut values = Vec::new();
for _ in 0..1000 {
// get the next pseudo-random field element
let value = self.draw_basefield().as_int();
// use the mask to get a value within the range
let value = (value & v_mask) as usize;
values.push(value);
if values.len() == num_values {
break;
}
}
if values.len() < num_values {
return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
}
Ok(values)
}
}
// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl FeltRng for RpxRandomCoin {
fn draw_element(&mut self) -> Felt {
self.draw_basefield()
}
fn draw_word(&mut self) -> Word {
let mut output = [ZERO; 4];
for o in output.iter_mut() {
*o = self.draw_basefield();
}
output
}
}
// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RngCore for RpxRandomCoin {
fn next_u32(&mut self) -> u32 {
self.draw_basefield().as_int() as u32
}
fn next_u64(&mut self) -> u64 {
impls::next_u64_via_u32(self)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_next(self, dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
self.fill_bytes(dest);
Ok(())
}
}
// SERIALIZATION
// ------------------------------------------------------------------------------------------------
impl Serializable for RpxRandomCoin {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.state.iter().for_each(|v| v.write_into(target));
// casting to u8 is OK because `current` is always between 4 and 12.
target.write_u8(self.current as u8);
}
}
impl Deserializable for RpxRandomCoin {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let state = [
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
];
let current = source.read_u8()? as usize;
if !(RATE_START..RATE_END).contains(&current) {
return Err(DeserializationError::InvalidValue(
"current value outside of valid range".to_string(),
));
}
Ok(Self { state, current })
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
use crate::ONE;
#[test]
fn test_feltrng_felt() {
let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
let output = rpxcoin.draw_element();
let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
let expected = rpxcoin.draw_basefield();
assert_eq!(output, expected);
}
#[test]
fn test_feltrng_word() {
let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
let output = rpxcoin.draw_word();
let mut rpocoin = RpxRandomCoin::new([ZERO; 4]);
let mut expected = [ZERO; 4];
for o in expected.iter_mut() {
*o = rpocoin.draw_basefield();
}
assert_eq!(output, expected);
}
#[test]
fn test_feltrng_serialization() {
let coin1 = RpxRandomCoin::from_parts([ONE; 12], 5);
let bytes = coin1.to_bytes();
let coin2 = RpxRandomCoin::read_from_bytes(&bytes).unwrap();
assert_eq!(coin1, coin2);
}
}

View file

@ -1,402 +0,0 @@
use alloc::{
boxed::Box,
collections::{BTreeMap, BTreeSet},
};
use core::cell::RefCell;
// KEY-VALUE MAP TRAIT
// ================================================================================================
/// A trait that defines the interface for a key-value map.
pub trait KvMap<K: Ord + Clone, V: Clone>:
Extend<(K, V)> + FromIterator<(K, V)> + IntoIterator<Item = (K, V)>
{
fn get(&self, key: &K) -> Option<&V>;
fn contains_key(&self, key: &K) -> bool;
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
fn insert(&mut self, key: K, value: V) -> Option<V>;
fn remove(&mut self, key: &K) -> Option<V>;
fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_>;
}
// BTREE MAP `KvMap` IMPLEMENTATION
// ================================================================================================
impl<K: Ord + Clone, V: Clone> KvMap<K, V> for BTreeMap<K, V> {
fn get(&self, key: &K) -> Option<&V> {
self.get(key)
}
fn contains_key(&self, key: &K) -> bool {
self.contains_key(key)
}
fn len(&self) -> usize {
self.len()
}
fn insert(&mut self, key: K, value: V) -> Option<V> {
self.insert(key, value)
}
fn remove(&mut self, key: &K) -> Option<V> {
self.remove(key)
}
fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
Box::new(self.iter())
}
}
// RECORDING MAP
// ================================================================================================
/// A [RecordingMap] that records read requests to the underlying key-value map.
///
/// The data recorder is used to generate a proof for read requests.
///
/// The [RecordingMap] is composed of three parts:
/// - `data`: which contains the current set of key-value pairs in the map.
/// - `updates`: which tracks keys for which values have been changed since the map was
/// instantiated. updates include both insertions, removals and updates of values under existing
/// keys.
/// - `trace`: which contains the key-value pairs from the original data which have been accesses
/// since the map was instantiated.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct RecordingMap<K, V> {
data: BTreeMap<K, V>,
updates: BTreeSet<K>,
trace: RefCell<BTreeMap<K, V>>,
}
impl<K: Ord + Clone, V: Clone> RecordingMap<K, V> {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Returns a new [RecordingMap] instance initialized with the provided key-value pairs.
/// ([BTreeMap]).
pub fn new(init: impl IntoIterator<Item = (K, V)>) -> Self {
RecordingMap {
data: init.into_iter().collect(),
updates: BTreeSet::new(),
trace: RefCell::new(BTreeMap::new()),
}
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
pub fn inner(&self) -> &BTreeMap<K, V> {
&self.data
}
// FINALIZER
// --------------------------------------------------------------------------------------------
/// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first
/// element of the tuple is a map that represents the state of the map at the time `.finalize()`
/// is called. The second element contains the key-value pairs from the initial data set that
/// were read during recording.
pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
(self.data, self.trace.take())
}
// TEST HELPERS
// --------------------------------------------------------------------------------------------
#[cfg(test)]
pub fn trace_len(&self) -> usize {
self.trace.borrow().len()
}
#[cfg(test)]
pub fn updates_len(&self) -> usize {
self.updates.len()
}
}
impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns a reference to the value associated with the given key if the value exists.
///
/// If the key is part of the initial data set, the key access is recorded.
fn get(&self, key: &K) -> Option<&V> {
self.data.get(key).inspect(|&value| {
if !self.updates.contains(key) {
self.trace.borrow_mut().insert(key.clone(), value.clone());
}
})
}
/// Returns a boolean to indicate whether the given key exists in the data set.
///
/// If the key is part of the initial data set, the key access is recorded.
fn contains_key(&self, key: &K) -> bool {
self.get(key).is_some()
}
/// Returns the number of key-value pairs in the data set.
fn len(&self) -> usize {
self.data.len()
}
// MUTATORS
// --------------------------------------------------------------------------------------------
/// Inserts a key-value pair into the data set.
///
/// If the key already exists in the data set, the value is updated and the old value is
/// returned.
fn insert(&mut self, key: K, value: V) -> Option<V> {
let new_update = self.updates.insert(key.clone());
self.data.insert(key.clone(), value).inspect(|old_value| {
if new_update {
self.trace.borrow_mut().insert(key, old_value.clone());
}
})
}
/// Removes a key-value pair from the data set.
///
/// If the key exists in the data set, the old value is returned.
fn remove(&mut self, key: &K) -> Option<V> {
self.data.remove(key).inspect(|old_value| {
let new_update = self.updates.insert(key.clone());
if new_update {
self.trace.borrow_mut().insert(key.clone(), old_value.clone());
}
})
}
// ITERATION
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the key-value pairs in the data set.
fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
Box::new(self.data.iter())
}
}
impl<K: Clone + Ord, V: Clone> Extend<(K, V)> for RecordingMap<K, V> {
fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
iter.into_iter().for_each(move |(k, v)| {
self.insert(k, v);
});
}
}
impl<K: Clone + Ord, V: Clone> FromIterator<(K, V)> for RecordingMap<K, V> {
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
Self::new(iter)
}
}
impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
type Item = (K, V);
type IntoIter = alloc::collections::btree_map::IntoIter<K, V>;
fn into_iter(self) -> Self::IntoIter {
self.data.into_iter()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::*;
const ITEMS: [(u64, u64); 5] = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)];
#[test]
fn test_get_item() {
// instantiate a recording map
let map = RecordingMap::new(ITEMS.to_vec());
// get a few items
let get_items = [0, 1, 2];
for key in get_items.iter() {
map.get(key);
}
// convert the map into a proof
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, value) in ITEMS.iter() {
match get_items.contains(key) {
true => assert_eq!(proof.get(key), Some(value)),
false => assert_eq!(proof.get(key), None),
}
}
}
#[test]
fn test_contains_key() {
// instantiate a recording map
let map = RecordingMap::new(ITEMS.to_vec());
// check if the map contains a few items
let get_items = [0, 1, 2];
for key in get_items.iter() {
map.contains_key(key);
}
// convert the map into a proof
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, _) in ITEMS.iter() {
match get_items.contains(key) {
true => assert!(proof.contains_key(key)),
false => assert!(!proof.contains_key(key)),
}
}
}
#[test]
fn test_len() {
// instantiate a recording map
let mut map = RecordingMap::new(ITEMS.to_vec());
// length of the map should be equal to the number of items
assert_eq!(map.len(), ITEMS.len());
// inserting entry with key that already exists should not change the length, but it does
// add entries to the trace and update sets
map.insert(4, 5);
assert_eq!(map.len(), ITEMS.len());
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 1);
// inserting entry with new key should increase the length; it should also record the key
// as an updated key, but the trace length does not change since old values were not touched
map.insert(5, 5);
assert_eq!(map.len(), ITEMS.len() + 1);
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 2);
// get some items so that they are saved in the trace; this should record original items
// in the trace, but should not affect the set of updates
let get_items = [0, 1, 2];
for key in get_items.iter() {
map.contains_key(key);
}
assert_eq!(map.trace_len(), 4);
assert_eq!(map.updates_len(), 2);
// read the same items again, this should not have any effect on either length, trace, or
// the set of updates
let get_items = [0, 1, 2];
for key in get_items.iter() {
map.contains_key(key);
}
assert_eq!(map.trace_len(), 4);
assert_eq!(map.updates_len(), 2);
// read a newly inserted item; this should not affect either length, trace, or the set of
// updates
let _val = map.get(&5).unwrap();
assert_eq!(map.trace_len(), 4);
assert_eq!(map.updates_len(), 2);
// update a newly inserted item; this should not affect either length, trace, or the set
// of updates
map.insert(5, 11);
assert_eq!(map.trace_len(), 4);
assert_eq!(map.updates_len(), 2);
// Note: The length reported by the proof will be different to the length originally
// reported by the map.
let (_, proof) = map.finalize();
// length of the proof should be equal to get_items + 1. The extra item is the original
// value at key = 4u64
assert_eq!(proof.len(), get_items.len() + 1);
}
#[test]
fn test_iter() {
let mut map = RecordingMap::new(ITEMS.to_vec());
assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));
// when inserting entry with key that already exists the iterator should return the new
// value
let new_value = 5;
map.insert(4, new_value);
assert_eq!(map.iter().count(), ITEMS.len());
assert!(map.iter().all(|(x, y)| if x == &4 {
y == &new_value
} else {
ITEMS.contains(&(*x, *y))
}));
}
#[test]
fn test_is_empty() {
// instantiate an empty recording map
let empty_map: RecordingMap<u64, u64> = RecordingMap::default();
assert!(empty_map.is_empty());
// instantiate a non-empty recording map
let map = RecordingMap::new(ITEMS.to_vec());
assert!(!map.is_empty());
}
#[test]
fn test_remove() {
let mut map = RecordingMap::new(ITEMS.to_vec());
// remove an item that exists
let key = 0;
let value = map.remove(&key).unwrap();
assert_eq!(value, ITEMS[0].1);
assert_eq!(map.len(), ITEMS.len() - 1);
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 1);
// add the item back and then remove it again
let key = 0;
let value = 0;
map.insert(key, value);
let value = map.remove(&key).unwrap();
assert_eq!(value, 0);
assert_eq!(map.len(), ITEMS.len() - 1);
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 1);
// remove an item that does not exist
let key = 100;
let value = map.remove(&key);
assert_eq!(value, None);
assert_eq!(map.len(), ITEMS.len() - 1);
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 1);
// insert a new item and then remove it
let key = 100;
let value = 100;
map.insert(key, value);
let value = map.remove(&key).unwrap();
assert_eq!(value, 100);
assert_eq!(map.len(), ITEMS.len() - 1);
assert_eq!(map.trace_len(), 1);
assert_eq!(map.updates_len(), 2);
// convert the map into a proof
let (_, proof) = map.finalize();
// check that the proof contains the expected values
for (key, value) in ITEMS.iter() {
match key {
0 => assert_eq!(proof.get(key), Some(value)),
_ => assert_eq!(proof.get(key), None),
}
}
}
}

Some files were not shown because too many files have changed in this diff Show more