Compare commits
No commits in common. "9fe2673a3362ae546da0ae1dd073eb8b217dec7a" and "bbe11964b1bd1f57bbebc406d1bf7e0e04ac2ac4" have entirely different histories.
9fe2673a33...bbe11964b1
101 changed files with 27088 additions and 1 deletion
3  .config/nextest.toml  Normal file
@@ -0,0 +1,3 @@
[profile.default]
failure-output = "immediate-final"
fail-fast = false
20  .editorconfig  Normal file
@@ -0,0 +1,20 @@
# Documentation available at editorconfig.org

root = true

[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.rs]
max_line_length = 100

[*.md]
trim_trailing_whitespace = false

[*.yml]
indent_size = 2
9  .github/pull_request_template.md  vendored  Normal file
@@ -0,0 +1,9 @@
## Describe your changes

## Checklist before requesting a review
- Repo forked and branch created from `next` according to the naming convention.
- Commit messages and code style follow [conventions](./CONTRIBUTING.md).
- Relevant issues are linked in the PR description.
- Tests added for new functionality.
- Documentation/comments updated according to changes.
25  .github/workflows/build.yml  vendored  Normal file
@@ -0,0 +1,25 @@
# Runs build related jobs.

name: build

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  no-std:
    name: Build for no-std
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
    steps:
      - uses: actions/checkout@main
      - name: Build for no-std
        run: |
          rustup update --no-self-update ${{ matrix.toolchain }}
          rustup target add wasm32-unknown-unknown
          make build-no-std
23  .github/workflows/changelog.yml  vendored  Normal file
@@ -0,0 +1,23 @@
# Runs changelog related jobs.
# CI job heavily inspired by: https://github.com/tarides/changelog-check-action

name: changelog

on:
  pull_request:
    types: [opened, reopened, synchronize, labeled, unlabeled]

jobs:
  changelog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@main
        with:
          fetch-depth: 0
      - name: Check for changes in changelog
        env:
          BASE_REF: ${{ github.event.pull_request.base.ref }}
          NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }}
        run: ./scripts/check-changelog.sh "${{ inputs.changelog }}"
        shell: bash
53  .github/workflows/lint.yml  vendored  Normal file
@@ -0,0 +1,53 @@
# Runs linting related jobs.

name: lint

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  clippy:
    name: clippy nightly on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Clippy
        run: |
          rustup update --no-self-update nightly
          rustup +nightly component add clippy
          make clippy

  rustfmt:
    name: rustfmt check nightly on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Rustfmt
        run: |
          rustup update --no-self-update nightly
          rustup +nightly component add rustfmt
          make format-check

  doc:
    name: doc stable on ubuntu-latest
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
      - name: Build docs
        run: |
          rustup update --no-self-update
          make doc

  version:
    name: check rust version consistency
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@main
        with:
          profile: minimal
          override: true
      - name: check rust versions
        run: ./scripts/check-rust-version.sh
28  .github/workflows/test.yml  vendored  Normal file
@@ -0,0 +1,28 @@
# Runs test related jobs.

name: test

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  test:
    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
    runs-on: ${{matrix.os}}-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
        os: [ubuntu]
        args: [default, smt-hashmaps, no-std]
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@main
      - uses: taiki-e/install-action@nextest
      - name: Perform tests
        run: |
          rustup update --no-self-update ${{matrix.toolchain}}
          make test-${{matrix.args}}
12  .gitignore  vendored  Normal file
@@ -0,0 +1,12 @@
# Generated by Cargo
# will have compiled files and executables
/target/

# These are backup files generated by rustfmt
**/*.rs.bk

# Generated by cmake
cmake-build-*

# VS Code
.vscode/
0  .gitmodules  vendored  Normal file
34  .pre-commit-config.yaml  Normal file
@@ -0,0 +1,34 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-json
      - id: check-toml
      - id: pretty-format-json
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-executables-have-shebangs
      - id: check-merge-conflict
      - id: detect-private-key
  - repo: local
    hooks:
      - id: lint
        name: Make lint
        stages: [commit]
        language: rust
        entry: make lint
      - id: doc
        name: Make doc
        stages: [commit]
        language: rust
        entry: make doc
      - id: check
        name: Make check
        stages: [commit]
        language: rust
        entry: make check
186  CHANGELOG.md  Normal file
@@ -0,0 +1,186 @@
## 0.14.0 (TBD)

- [BREAKING] Increment minimum supported Rust version to 1.84.
- Removed duplicated check in RpoFalcon512 verification (#368).
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365).
- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365).
- [BREAKING] Updated Winterfell dependency to v0.12 (#374).
- Added debug-only duplicate column check in `build_subtree` (#378).

## 0.13.3 (2025-02-12)

- Implement `PartialSmt` (#372).

## 0.13.2 (2025-01-24)

- Made `InnerNode` and `NodeMutation` public. Implemented (de)serialization of `LeafIndex` (#367).

## 0.13.1 (2024-12-26)

- Generate a reverse mutations set when applying a mutations set; implemented serialization of `MutationsSet` (#355).

## 0.13.0 (2024-11-24)

- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363).

## 0.12.0 (2024-10-30)

- [BREAKING] Updated Winterfell dependency to v0.10 (#338).
- Added parallel implementation of `Smt::with_entries()` with significantly better performance when the `concurrent` feature is enabled (#341).

## 0.11.0 (2024-10-17)

- [BREAKING] Renamed `Mmr::open()` to `Mmr::open_at()` and `Mmr::peaks()` to `Mmr::peaks_at()` (#234).
- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks_at()` respectively (#234).
- Standardized CI and Makefile across Miden repos (#323).
- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
- Changed padding rule for RPO/RPX hash functions (#318).
- [BREAKING] Changed return value of the `Mmr::verify()` and `MerklePath::verify()` from `bool` to `Result<>` (#335).
- Added `is_empty()` functions to the `SimpleSmt` and `Smt` structures. Added `EMPTY_ROOT` constant to the `SparseMerkleTree` trait (#337).

## 0.10.3 (2024-09-25)

- Implement `get_size_hint` for `Smt` (#331).

## 0.10.2 (2024-09-25)

- Implement `get_size_hint` for `RpoDigest` and `RpxDigest` and expose constants for their serialized size (#330).

## 0.10.1 (2024-09-13)

- Added `Serializable` and `Deserializable` implementations for `PartialMmr` and `InOrderIndex` (#329).

## 0.10.0 (2024-08-06)

- Added more `RpoDigest` and `RpxDigest` conversions (#311).
- [BREAKING] Migrated to Winterfell v0.9 (#315).
- Fixed encoding of Falcon secret key (#319).

## 0.9.3 (2024-04-24)

- Added `RpxRandomCoin` struct (#307).

## 0.9.2 (2024-04-21)

- Implemented serialization for the `Smt` struct (#304).
- Fixed a bug in Falcon signature generation (#305).

## 0.9.1 (2024-04-02)

- Added `num_leaves()` method to `SimpleSmt` (#302).

## 0.9.0 (2024-03-24)

- [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
- [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
- [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).

## 0.8.4 (2024-03-17)

- Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).

## 0.8.3 (2024-03-17)

- Re-added unintentionally removed re-exported liballoc macros (#292).

## 0.8.2 (2024-03-17)

- Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).

## 0.8.1 (2024-02-21)

- Fixed clippy warnings (#280).

## 0.8.0 (2024-02-14)

- Implemented the `PartialMmr` data structure (#195).
- Implemented RPX hash function (#201).
- Added `FeltRng` and `RpoRandomCoin` (#237).
- Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
- Added `inner_nodes()` method to `PartialMmr` (#238).
- Improved `PartialMmr::apply_delta()` (#242).
- Refactored `SimpleSmt` struct (#245).
- Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
- Updated Winterfell dependency to v0.8 (#275).

## 0.7.1 (2023-10-10)

- Fixed RPO Falcon signature build on Windows.

## 0.7.0 (2023-10-05)

- Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
- Implemented clearing of nodes in `TieredSmt` (#173).
- Added ability to generate inclusion proofs for `TieredSmt` (#174).
- Implemented Falcon DSA (#179).
- Added conditional `serde` support for various structs (#180).
- Implemented benchmarking for `TieredSmt` (#182).
- Added more leaf traversal methods for `MerkleStore` (#185).
- Added SVE acceleration for RPO hash function (#189).

## 0.6.0 (2023-06-25)

- [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
- [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
- Added initial implementation of `PartialMerkleTree` (#156).

## 0.5.0 (2023-05-26)

- Implemented `TieredSmt` (#152, #153).
- Implemented ability to extract a subset of a `MerkleStore` (#151).
- Cleaned up `SimpleSmt` interface (#149).
- Decoupled hashing and padding of peaks in `Mmr` (#148).
- Added `inner_nodes()` to `MerkleStore` (#146).

## 0.4.0 (2023-04-21)

- Exported `MmrProof` from the crate (#137).
- Allowed merging of leaves in `MerkleStore` (#138).
- [BREAKING] Refactored how existing data structures are added to `MerkleStore` (#139).

## 0.3.0 (2023-04-08)

- Added `depth` parameter to SMT constructors in `MerkleStore` (#115).
- Optimized MMR peak hashing for Miden VM (#120).
- Added `get_leaf_depth` method to `MerkleStore` (#119).
- Added inner node iterators to `MerkleTree`, `SimpleSmt`, and `Mmr` (#117, #118, #121).

## 0.2.0 (2023-03-24)

- Implemented `Mmr` and related structs (#67).
- Implemented `MerkleStore` (#93, #94, #95, #107, #112).
- Added benchmarks for `MerkleStore` vs. other structs (#97).
- Added Merkle path containers (#99).
- Fixed depth handling in `MerklePathSet` (#110).
- Updated Winterfell dependency to v0.6.

## 0.1.4 (2023-02-22)

- Re-export winter-crypto Hasher, Digest & ElementHasher (#72).

## 0.1.3 (2023-02-20)

- Updated Winterfell dependency to v0.5.1 (#68).

## 0.1.2 (2023-02-17)

- Fixed `Rpo256::hash` pad that was panicking on input (#44).
- Added `MerklePath` wrapper to encapsulate Merkle opening verification and root computation (#53).
- Added `NodeIndex` Merkle wrapper to encapsulate Merkle tree traversal and mappings (#54).

## 0.1.1 (2023-02-06)

- Introduced `merge_in_domain` for the RPO hash function, to allow using a specified domain value in the second capacity register when hashing two digests together.
- Added a simple sparse Merkle tree implementation.
- Added re-exports of Winterfell RandomCoin and RandomCoinError.

## 0.1.0 (2022-12-02)

- Initial release on crates.io containing the cryptographic primitives used in Miden VM and the Miden Rollup.
- Hash module with the BLAKE3 and Rescue Prime Optimized hash functions.
  - BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output.
  - RPO is implemented with 256-bit output.
- Merkle module, with a set of data structures related to Merkle trees, implemented using the RPO hash function.
108  CONTRIBUTING.md  Normal file
@@ -0,0 +1,108 @@
# Contributing to Crypto

#### First off, thanks for taking the time to contribute!

We want to make contributing to this project as easy and transparent as possible, whether it's:

- Reporting a [bug](https://github.com/0xPolygonMiden/crypto/issues/new)
- Taking part in [discussions](https://github.com/0xPolygonMiden/crypto/discussions)
- Submitting a [fix](https://github.com/0xPolygonMiden/crypto/pulls)
- Proposing new [features](https://github.com/0xPolygonMiden/crypto/issues/new)

## Flow
We are using [GitHub Flow](https://docs.github.com/en/get-started/quickstart/github-flow), so all code changes happen through pull requests from a [forked repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo).

### Branching
- The current active branch is `next`. Every branch with a fix/feature must be forked from `next`.

- The branch name should contain a short issue/feature description separated with hyphens [(kebab-case)](https://en.wikipedia.org/wiki/Letter_case#Kebab_case).

  For example, if the issue title is `Fix functionality X in component Y` then the branch name will be something like: `fix-x-in-y`.

- The new branch should be rebased onto `next` before submitting a PR, in case there have been changes, to avoid merge commits. I.e. this branch's state:
  ```
    A---B---C fix-x-in-y
   /
  D---E---F---G next
          |   |
          (F, G) changes happened after `fix-x-in-y` forked
  ```

  should become this after rebase:

  ```
                A'--B'--C' fix-x-in-y
               /
  D---E---F---G next
  ```

  More about rebase [here](https://git-scm.com/docs/git-rebase) and [here](https://www.atlassian.com/git/tutorials/rewriting-history/git-rebase#:~:text=What%20is%20git%20rebase%3F,of%20a%20feature%20branching%20workflow.).

### Commit messages
- Commit messages should be written in a short, descriptive manner and be prefixed with tags for the change type and scope (if possible) according to the [semantic commit](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) scheme.

  For example, a new change to the AIR crate might have the following message: `feat(air): add constraints for the decoder`

- Also, squash commits into logically separated, distinguishable stages to keep the git log clean:
  ```
  7hgf8978g9... Added A to X  \
                               \ (squash)
  gh354354gh... oops, typo   --- * ---------> 9fh1f51gh7... feat(X): add A && B
                               /
  85493g2458... Added B to X  /


  789fdfffdf... Fixed D in Y  \
                               \ (squash)
  787g8fgf78... blah blah   --- * ---------> 4070df6f00... fix(Y): fixed D && C
                               /
  9080gf6567... Fixed C in Y  /
  ```

### Code Style and Documentation
- For documentation in the codebase, we follow the [rustdoc](https://doc.rust-lang.org/rust-by-example/meta/doc.html) convention with no more than 100 characters per line.
- For code sections, we use code separators like the following to a width of 100 characters:
  ```
  // CODE SECTION HEADER
  // ================================================================================
  ```

- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in the CI pipeline. It is still preferable to run linting locally before pushing:
  ```
  cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
  ```

### Versioning
We follow the [semver](https://semver.org/) naming convention.

## Pre-PR checklist
1. Repo forked and branch created from `next` according to the naming convention.
2. Commit messages and code style follow conventions.
3. Tests added for new functionality.
4. Documentation/comments updated for all changes according to our documentation convention.
5. Clippy and Rustfmt linting passed.
6. New branch rebased from `next`.

## Write bug reports with detail, background, and sample code

**Great Bug Reports** tend to have:

- A quick summary and/or background
- Steps to reproduce
- What you expected would happen
- What actually happens
- Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)

## Any contributions you make will be under the MIT Software License
In short, when you submit code changes, your submissions are understood to be under the same [MIT License](http://choosealicense.com/licenses/mit/) that covers the project. Feel free to contact the maintainers if that's a concern.
1297  Cargo.lock  generated  Normal file
File diff suppressed because it is too large
93  Cargo.toml  Normal file
@@ -0,0 +1,93 @@
[package]
name = "miden-crypto"
version = "0.14.0"
description = "Miden Cryptographic primitives"
authors = ["miden contributors"]
readme = "README.md"
license = "MIT"
repository = "https://github.com/0xPolygonMiden/crypto"
documentation = "https://docs.rs/miden-crypto/0.14.0"
categories = ["cryptography", "no-std"]
keywords = ["miden", "crypto", "hash", "merkle"]
edition = "2021"
rust-version = "1.84"

[[bin]]
name = "miden-crypto"
path = "src/main.rs"
bench = false
doctest = false
required-features = ["executable"]

[[bench]]
name = "hash"
harness = false

[[bench]]
name = "smt"
harness = false

[[bench]]
name = "smt-subtree"
harness = false
required-features = ["internal"]

[[bench]]
name = "merkle"
harness = false

[[bench]]
name = "smt-with-entries"
harness = false

[[bench]]
name = "store"
harness = false

[features]
concurrent = ["dep:rayon", "hashbrown?/rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = ["dep:hashbrown"]
internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
    "blake3/std",
    "dep:cc",
    "rand/std",
    "rand/std_rng",
    "winter-crypto/std",
    "winter-math/std",
    "winter-utils/std",
]

[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
hashbrown = { version = "0.15", optional = true, features = ["serde"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }
rand_core = { version = "0.6", default-features = false }
rand-utils = { version = "0.12", package = "winter-rand-utils", optional = true }
rayon = { version = "1.10", optional = true }
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
sha3 = { version = "0.10", default-features = false }
thiserror = { version = "2.0", default-features = false }
winter-crypto = { version = "0.12", default-features = false }
winter-math = { version = "0.12", default-features = false }
winter-utils = { version = "0.12", default-features = false }

[dev-dependencies]
assert_matches = { version = "1.5", default-features = false }
criterion = { version = "0.5", features = ["html_reports"] }
getrandom = { version = "0.2", features = ["js"] }
hex = { version = "0.4", default-features = false, features = ["alloc"] }
proptest = "1.6"
rand_chacha = { version = "0.3", default-features = false }
rand-utils = { version = "0.12", package = "winter-rand-utils" }
seq-macro = { version = "0.3" }

[build-dependencies]
cc = { version = "1.2", optional = true, features = ["parallel"] }
glob = "0.3"
21  LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Polygon Miden

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
93  Makefile  Normal file
@@ -0,0 +1,93 @@
.DEFAULT_GOAL := help

.PHONY: help
help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

# --- variables -----------------------------------------------------------------------------------

WARNINGS=RUSTDOCFLAGS="-D warnings"
DEBUG_OVERFLOW_INFO=RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2"

# --- linting -------------------------------------------------------------------------------------

.PHONY: clippy
clippy: ## Run Clippy with configs
	$(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features

.PHONY: fix
fix: ## Run Fix with configs
	cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features

.PHONY: format
format: ## Run Format using nightly toolchain
	cargo +nightly fmt --all

.PHONY: format-check
format-check: ## Run Format using nightly toolchain but only in check mode
	cargo +nightly fmt --all --check

.PHONY: lint
lint: format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting)

# --- docs ----------------------------------------------------------------------------------------

.PHONY: doc
doc: ## Generate and check documentation
	$(WARNINGS) cargo doc --all-features --keep-going --release

# --- testing -------------------------------------------------------------------------------------

.PHONY: test-default
test-default: ## Run tests with default features
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features

.PHONY: test-smt-hashmaps
test-smt-hashmaps: ## Run tests with `smt_hashmaps` feature enabled
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --features smt_hashmaps

.PHONY: test-no-std
test-no-std: ## Run tests with no default features (no-std)
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --no-default-features

.PHONY: test
test: test-default test-smt-hashmaps test-no-std ## Run all tests

# --- checking ------------------------------------------------------------------------------------

.PHONY: check
check: ## Check all targets and features for errors without code generation
	cargo check --all-targets --all-features

# --- building ------------------------------------------------------------------------------------

.PHONY: build
build: ## Build with default features enabled
	cargo build --release

.PHONY: build-no-std
build-no-std: ## Build without the standard library
	cargo build --release --no-default-features --target wasm32-unknown-unknown

.PHONY: build-avx2
build-avx2: ## Build with avx2 support
	RUSTFLAGS="-C target-feature=+avx2" cargo build --release

.PHONY: build-sve
build-sve: ## Build with sve support
	RUSTFLAGS="-C target-feature=+sve" cargo build --release

# --- benchmarking --------------------------------------------------------------------------------

.PHONY: bench
bench: ## Run crypto benchmarks
	cargo bench --features concurrent

.PHONY: bench-smt-concurrent
bench-smt-concurrent: ## Run SMT benchmarks with concurrent feature
	cargo run --release --features concurrent,executable -- --size 1000000
110  README.md  Normal file
@@ -0,0 +1,110 @@
# Miden Crypto

[](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
[](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
[](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml)
[](https://www.rust-lang.org/tools/install)
[](https://crates.io/crates/miden-crypto)

This crate contains cryptographic primitives used in Polygon Miden.

## Hash

[Hash module](./src/hash) provides a set of cryptographic hash functions which are used by the Miden VM and the Miden rollup. Currently, these functions are:

- [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
- [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This is an algebraic hash function suitable for recursive STARKs.
- [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs, but it is about 2x faster than RPO.

For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
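As a quick illustration, a minimal sketch of hashing with RPO follows. It assumes the `Rpo256` type at `miden_crypto::hash::rpo::Rpo256` and the winter-crypto `Hasher` interface it implements; exact paths and signatures may differ between versions, so treat this as a sketch rather than canonical usage:

```rust
use miden_crypto::hash::rpo::Rpo256;

fn main() {
    // Hash arbitrary bytes into a 256-bit RPO digest.
    let digest = Rpo256::hash(b"some data");

    // Merge two digests, as done when computing a Merkle tree node
    // from its two children.
    let parent = Rpo256::merge(&[digest, digest]);
    println!("{:?}", parent);
}
```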
## Merkle

[Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:

- `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
- `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
- `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
- `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
- `PartialMmr`: a partial view of a Merkle mountain range structure.
- `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
- `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.

The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.

## Signatures

[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:

- `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that, instead of using the SHAKE256 hash function in the _hash-to-point_ algorithm, we use RPO256. This makes the signature more efficient to verify in the Miden VM.

For the above signatures, key generation, signing, and signature verification are available for both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in a `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.

## Pseudo-Random Element Generator

[Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of the Miden VM and the Miden rollup. The module currently includes:

- `FeltRng`: a trait for generating random field elements and random words (batches of 4 field elements).
- `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using the RPO hash function.
- `RpxRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using the RPX hash function.
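A hypothetical sketch of drawing elements from `RpoRandomCoin` is shown below. It assumes `RpoRandomCoin::new` takes a 4-element seed and that `FeltRng` exposes `draw_element`/`draw_word`, and that `Felt`, `Word`, and `ZERO` are re-exported from the crate root; check the crate docs for the exact signatures in your version:

```rust
use miden_crypto::rand::{FeltRng, RpoRandomCoin};
use miden_crypto::{Felt, Word, ZERO};

fn main() {
    // Seed the coin with a (here, trivial) 4-element word.
    let seed: Word = [ZERO; 4];
    let mut coin = RpoRandomCoin::new(seed);

    let e: Felt = coin.draw_element(); // one random field element
    let w: Word = coin.draw_word();    // four random field elements
    println!("{e:?} {w:?}");
}
```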
## Make commands

We use `make` to automate building, testing, and other processes. In most cases, `make` commands are wrappers around `cargo` commands with specific arguments. You can view the list of available commands in the [Makefile](Makefile), or run the following command:

```shell
make
```

## Crate features

This crate can be compiled with the following features:

- `concurrent` - enabled by default; enables a multi-threaded implementation of `Smt::with_entries()`, which significantly improves performance on multi-core CPUs.
- `std` - enabled by default; relies on the Rust standard library.
- `no_std` - does not rely on the Rust standard library and enables compilation to WebAssembly.
- `smt_hashmaps` - uses hashbrown hashmaps in the SMT implementation, which significantly improves the performance of SMT updates. Key ordering in SMT iterators is not guaranteed when this feature is enabled.

All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.

To compile with `no_std`, disable default features via the `--no-default-features` flag or use the following command:

```shell
make build-no-std
```

### AVX2 acceleration

On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:

```shell
make build-avx2
```

### SVE acceleration

On platforms with [SVE](<https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)>) support, the RPO and RPX hash functions can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:

```shell
make build-sve
```

## Testing

The best way to test the library is using our [Makefile](Makefile); this will enable you to use our pre-defined, optimized testing commands:

```shell
make test
```

Some of the functions are computation-heavy, so tests might take a while to complete with a plain `cargo test`. In order to test in release and optimized mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.

We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile](Makefile)):

```shell
RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
```

## License

This project is [MIT licensed](./LICENSE).
55  arch/arm64-sve/rpo/library.c  Normal file
@@ -0,0 +1,55 @@
#include <stddef.h>
#include <arm_sve.h>
#include "library.h"
#include "rpo_hash_128bit.h"
#include "rpo_hash_256bit.h"

// The STATE_WIDTH of RPO hash is 12x u64 elements.
// The current generation of SVE-enabled processors - Neoverse V1
// (e.g. AWS Graviton3) - have 256-bit vector registers (4x u64).
// This allows us to split the state into 3 vectors of 4 elements
// and process all 3 independent of each other.

// We see the biggest performance gains by leveraging both
// vector and scalar operations on parts of the state array.
// Due to the high latency of vector operations, the processor is able
// to reorder and pipeline scalar instructions while we wait for
// vector results. This effectively gives us some 'free' scalar
// operations and masks vector latency.
//
// This also means that we can fully saturate all four arithmetic
// units of the processor (2x scalar, 2x SIMD).
//
// THIS ANALYSIS NEEDS TO BE PERFORMED AGAIN ONCE PROCESSORS
// GAIN WIDER REGISTERS. It's quite possible that with 8x u64
// vectors, processing 2 partially filled vectors might
// be easier and faster than dealing with scalar operations
// on the remainder of the array.
//
// FOR NOW THIS IS ONLY ENABLED ON 4x u64 VECTORS! It falls back
// to the regular, already highly-optimized scalar version
// if the conditions are not met.

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = svcntd();  // number of u64 numbers in one SVE vector

    if (vl == 2) {
        return add_constants_and_apply_sbox_128(state, constants);
    } else if (vl == 4) {
        return add_constants_and_apply_sbox_256(state, constants);
    } else {
        return false;
    }
}

bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = svcntd();  // number of u64 numbers in one SVE vector

    if (vl == 2) {
        return add_constants_and_apply_inv_sbox_128(state, constants);
    } else if (vl == 4) {
        return add_constants_and_apply_inv_sbox_256(state, constants);
    } else {
        return false;
    }
}
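On the Rust side, these entry points would be bound roughly as follows. This is a hypothetical FFI sketch based only on `library.h`; the crate's real bindings may differ. The `bool` return is the fallback signal described in the comments above: `false` means the running CPU's SVE vector length is neither 128 nor 256 bits and the caller must use the scalar path.

```rust
const STATE_WIDTH: usize = 12;

extern "C" {
    // Signatures mirror library.h; `true` means the SVE path handled the call.
    fn add_constants_and_apply_sbox(state: *mut u64, constants: *mut u64) -> bool;
    fn add_constants_and_apply_inv_sbox(state: *mut u64, constants: *mut u64) -> bool;
}

fn sbox_layer(state: &mut [u64; STATE_WIDTH], ark: &mut [u64; STATE_WIDTH]) {
    // SAFETY: both pointers reference arrays of exactly STATE_WIDTH u64s.
    let handled = unsafe { add_constants_and_apply_sbox(state.as_mut_ptr(), ark.as_mut_ptr()) };
    if !handled {
        // Fall back to the portable scalar implementation (not shown here).
    }
}
```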
12  arch/arm64-sve/rpo/library.h  Normal file
@@ -0,0 +1,12 @@
#ifndef CRYPTO_LIBRARY_H
#define CRYPTO_LIBRARY_H

#include <stdint.h>
#include <stdbool.h>

#define STATE_WIDTH 12

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);
bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]);

#endif //CRYPTO_LIBRARY_H
318  arch/arm64-sve/rpo/rpo_hash_128bit.h  Normal file
@@ -0,0 +1,318 @@
#ifndef RPO_SVE_RPO_HASH_128_H
#define RPO_SVE_RPO_HASH_128_H

#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define STATE_WIDTH 12

#define COPY_128(NAME, VIN1, VIN2, VIN3, VIN4, SIN) \
    svuint64_t NAME ## _1 = VIN1;                   \
    svuint64_t NAME ## _2 = VIN2;                   \
    svuint64_t NAME ## _3 = VIN3;                   \
    svuint64_t NAME ## _4 = VIN4;                   \
    uint64_t NAME ## _tail[4];                      \
    memcpy(NAME ## _tail, SIN, 4 * sizeof(uint64_t))

#define MULTIPLY_128(PRED, DEST, OP) \
    mul_128(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, &DEST ## _3, &OP ## _3, &DEST ## _4, &OP ## _4, DEST ## _tail, OP ## _tail)

#define SQUARE_128(PRED, NAME) \
    sq_128(PRED, &NAME ## _1, &NAME ## _2, &NAME ## _3, &NAME ## _4, NAME ## _tail)

#define SQUARE_DEST_128(PRED, DEST, SRC)                                      \
    COPY_128(DEST, SRC ## _1, SRC ## _2, SRC ## _3, SRC ## _4, SRC ## _tail); \
    SQUARE_128(PRED, DEST);

#define POW_ACC_128(PRED, NAME, CNT, TAIL) \
    for (size_t i = 0; i < CNT; i++) {     \
        SQUARE_128(PRED, NAME);            \
    }                                      \
    MULTIPLY_128(PRED, NAME, TAIL);

#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL)                                  \
    COPY_128(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3, HEAD ## _4, HEAD ## _tail); \
    POW_ACC_128(PRED, DEST, CNT, TAIL)

extern inline void add_constants_128(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *const1,
    svuint64_t *state2,
    svuint64_t *const2,
    svuint64_t *state3,
    svuint64_t *const3,
    svuint64_t *state4,
    svuint64_t *const4,
    uint64_t *state_tail,
    uint64_t *const_tail
) {
    uint64_t Ms = 0xFFFFFFFF00000001ull;
    svuint64_t Mv = svindex_u64(Ms, 0);

    uint64_t p_1 = Ms - const_tail[0];
    uint64_t p_2 = Ms - const_tail[1];
    uint64_t p_3 = Ms - const_tail[2];
    uint64_t p_4 = Ms - const_tail[3];

    uint64_t x_1, x_2, x_3, x_4;
    uint32_t adj_1 = -__builtin_sub_overflow(state_tail[0], p_1, &x_1);
    uint32_t adj_2 = -__builtin_sub_overflow(state_tail[1], p_2, &x_2);
    uint32_t adj_3 = -__builtin_sub_overflow(state_tail[2], p_3, &x_3);
    uint32_t adj_4 = -__builtin_sub_overflow(state_tail[3], p_4, &x_4);

    state_tail[0] = x_1 - (uint64_t)adj_1;
    state_tail[1] = x_2 - (uint64_t)adj_2;
    state_tail[2] = x_3 - (uint64_t)adj_3;
    state_tail[3] = x_4 - (uint64_t)adj_4;

    svuint64_t p1 = svsub_x(pg, Mv, *const1);
    svuint64_t p2 = svsub_x(pg, Mv, *const2);
    svuint64_t p3 = svsub_x(pg, Mv, *const3);
    svuint64_t p4 = svsub_x(pg, Mv, *const4);

    svuint64_t x1 = svsub_x(pg, *state1, p1);
    svuint64_t x2 = svsub_x(pg, *state2, p2);
    svuint64_t x3 = svsub_x(pg, *state3, p3);
    svuint64_t x4 = svsub_x(pg, *state4, p4);

    svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
    svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
    svbool_t pt3 = svcmplt_u64(pg, *state3, p3);
    svbool_t pt4 = svcmplt_u64(pg, *state4, p4);

    *state1 = svsub_m(pt1, x1, (uint32_t)-1);
    *state2 = svsub_m(pt2, x2, (uint32_t)-1);
    *state3 = svsub_m(pt3, x3, (uint32_t)-1);
    *state4 = svsub_m(pt4, x4, (uint32_t)-1);
}
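The constant injection above avoids a conditional reduction: rather than computing `x + c mod p` directly, it subtracts the precomputed complement `p - c` and, on borrow, subtracts `2^32 - 1` (which is `-p mod 2^64`). A one-lane Rust sketch of the same trick, assuming both inputs are canonical field elements:

```rust
/// Goldilocks prime p = 2^64 - 2^32 + 1, the `Ms` constant above.
const P: u64 = 0xFFFF_FFFF_0000_0001;

/// Branch-free `(x + c) mod p` for x, c < p, mirroring `add_constants_128`.
fn add_mod_p(x: u64, c: u64) -> u64 {
    // x - (p - c) equals x + c - p when it does not underflow.
    let (res, borrow) = x.overflowing_sub(P - c);
    // On underflow the wrapped value is x + c - p + 2^64; subtracting
    // 2^32 - 1 (= 2^64 - p, taken mod 2^64) yields x + c exactly.
    let adj = 0u32.wrapping_sub(borrow as u32) as u64;
    res.wrapping_sub(adj)
}
```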
extern inline void mul_128(
    svbool_t pg,
    svuint64_t *r1,
    const svuint64_t *op1,
    svuint64_t *r2,
    const svuint64_t *op2,
    svuint64_t *r3,
    const svuint64_t *op3,
    svuint64_t *r4,
    const svuint64_t *op4,
    uint64_t *r_tail,
    const uint64_t *op_tail
) {
    __uint128_t x_1 = r_tail[0];
    __uint128_t x_2 = r_tail[1];
    __uint128_t x_3 = r_tail[2];
    __uint128_t x_4 = r_tail[3];

    x_1 *= (__uint128_t) op_tail[0];
    x_2 *= (__uint128_t) op_tail[1];
    x_3 *= (__uint128_t) op_tail[2];
    x_4 *= (__uint128_t) op_tail[3];

    uint64_t x0_1 = x_1;
    uint64_t x0_2 = x_2;
    uint64_t x0_3 = x_3;
    uint64_t x0_4 = x_4;

    svuint64_t l1 = svmul_x(pg, *r1, *op1);
    svuint64_t l2 = svmul_x(pg, *r2, *op2);
    svuint64_t l3 = svmul_x(pg, *r3, *op3);
    svuint64_t l4 = svmul_x(pg, *r4, *op4);

    uint64_t x1_1 = (x_1 >> 64);
    uint64_t x1_2 = (x_2 >> 64);
    uint64_t x1_3 = (x_3 >> 64);
    uint64_t x1_4 = (x_4 >> 64);

    uint64_t a_1, a_2, a_3, a_4;
    uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
    uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
    uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
    uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);

    svuint64_t ls1 = svlsl_x(pg, l1, 32);
    svuint64_t ls2 = svlsl_x(pg, l2, 32);
    svuint64_t ls3 = svlsl_x(pg, l3, 32);
    svuint64_t ls4 = svlsl_x(pg, l4, 32);

    svuint64_t a1 = svadd_x(pg, l1, ls1);
    svuint64_t a2 = svadd_x(pg, l2, ls2);
    svuint64_t a3 = svadd_x(pg, l3, ls3);
    svuint64_t a4 = svadd_x(pg, l4, ls4);

    svbool_t e1 = svcmplt(pg, a1, l1);
    svbool_t e2 = svcmplt(pg, a2, l2);
    svbool_t e3 = svcmplt(pg, a3, l3);
    svbool_t e4 = svcmplt(pg, a4, l4);

    svuint64_t as1 = svlsr_x(pg, a1, 32);
    svuint64_t as2 = svlsr_x(pg, a2, 32);
    svuint64_t as3 = svlsr_x(pg, a3, 32);
    svuint64_t as4 = svlsr_x(pg, a4, 32);

    svuint64_t b1 = svsub_x(pg, a1, as1);
    svuint64_t b2 = svsub_x(pg, a2, as2);
    svuint64_t b3 = svsub_x(pg, a3, as3);
    svuint64_t b4 = svsub_x(pg, a4, as4);

    b1 = svsub_m(e1, b1, 1);
    b2 = svsub_m(e2, b2, 1);
    b3 = svsub_m(e3, b3, 1);
    b4 = svsub_m(e4, b4, 1);

    uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
    uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
    uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
    uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;

    uint64_t r_1, r_2, r_3, r_4;
    uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
    uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
    uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
    uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);

    svuint64_t h1 = svmulh_x(pg, *r1, *op1);
    svuint64_t h2 = svmulh_x(pg, *r2, *op2);
    svuint64_t h3 = svmulh_x(pg, *r3, *op3);
    svuint64_t h4 = svmulh_x(pg, *r4, *op4);

    svuint64_t tr1 = svsub_x(pg, h1, b1);
    svuint64_t tr2 = svsub_x(pg, h2, b2);
    svuint64_t tr3 = svsub_x(pg, h3, b3);
    svuint64_t tr4 = svsub_x(pg, h4, b4);

    svbool_t c1 = svcmplt_u64(pg, h1, b1);
    svbool_t c2 = svcmplt_u64(pg, h2, b2);
    svbool_t c3 = svcmplt_u64(pg, h3, b3);
    svbool_t c4 = svcmplt_u64(pg, h4, b4);

    *r1 = svsub_m(c1, tr1, (uint32_t) -1);
    *r2 = svsub_m(c2, tr2, (uint32_t) -1);
    *r3 = svsub_m(c3, tr3, (uint32_t) -1);
    *r4 = svsub_m(c4, tr4, (uint32_t) -1);

    uint32_t minus1_1 = 0 - c_1;
    uint32_t minus1_2 = 0 - c_2;
    uint32_t minus1_3 = 0 - c_3;
    uint32_t minus1_4 = 0 - c_4;

    r_tail[0] = r_1 - (uint64_t)minus1_1;
    r_tail[1] = r_2 - (uint64_t)minus1_2;
    r_tail[2] = r_3 - (uint64_t)minus1_3;
    r_tail[3] = r_4 - (uint64_t)minus1_4;
}
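The arithmetic in `mul_128` is a Montgomery multiplication for the Goldilocks prime `p = 2^64 - 2^32 + 1`: the 128-bit product is reduced by computing `x * 2^-64 mod p`, so elements kept in Montgomery form stay in that form. (One sanity check: feeding in `x = 1` yields `2^64 - 2^33 + 1 ≡ -2^32 (mod p)`, which is indeed `2^-64 mod p`.) A one-lane Rust sketch of the scalar-tail sequence above:

```rust
/// Montgomery multiply for the Goldilocks prime p = 2^64 - 2^32 + 1:
/// returns a * b * 2^-64 mod p, mirroring the scalar tail of `mul_128`.
fn mont_mul(a: u64, b: u64) -> u64 {
    let x = (a as u128) * (b as u128);
    let x0 = x as u64;         // low 64 bits of the product
    let x1 = (x >> 64) as u64; // high 64 bits of the product

    // Fold the low half: a0 = x0 + (x0 << 32) mod 2^64, with carry e,
    // then b0 = a0 - (a0 >> 32) - e. This is the (m * p) >> 64 term of
    // Montgomery REDC, specialized to this prime.
    let (a0, e) = x0.overflowing_add(x0 << 32);
    let b0 = a0.wrapping_sub(a0 >> 32).wrapping_sub(e as u64);

    // result = x1 - b0; on borrow, subtract 2^32 - 1 (= -p mod 2^64).
    let (r, borrow) = x1.overflowing_sub(b0);
    let adj = 0u32.wrapping_sub(borrow as u32) as u64;
    r.wrapping_sub(adj)
}
```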
extern inline void sq_128(svbool_t pg, svuint64_t *a, svuint64_t *b, svuint64_t *c, svuint64_t *d, uint64_t *e) {
    mul_128(pg, a, a, b, b, c, c, d, d, e, e);
}

extern inline void apply_sbox_128(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *state2,
    svuint64_t *state3,
    svuint64_t *state4,
    uint64_t *state_tail
) {
    COPY_128(x, *state1, *state2, *state3, *state4, state_tail);  // copy input to x
    SQUARE_128(pg, x);                                            // x contains input^2
    mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail);  // state contains input^3
    SQUARE_128(pg, x);                                            // x contains input^4
    mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail);  // state contains input^7
}
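`apply_sbox_128` evaluates the RPO S-box `x -> x^7` with two squarings and two multiplications. Per lane, the same addition chain looks like this (with `mul` standing for any field multiplication, e.g. the Montgomery multiply sketched earlier):

```rust
/// x^7 using the addition chain from `apply_sbox_128`:
/// two squarings and two multiplications.
fn pow7(x: u64, mul: impl Fn(u64, u64) -> u64) -> u64 {
    let x2 = mul(x, x);   // x^2
    let x3 = mul(x2, x);  // x^3
    let x4 = mul(x2, x2); // x^4
    mul(x3, x4)           // x^7
}
```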
extern inline void apply_inv_sbox_128(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *state2,
    svuint64_t *state3,
    svuint64_t *state4,
    uint64_t *state_tail
) {
    // base^10
    COPY_128(t1, *state1, *state2, *state3, *state4, state_tail);
    SQUARE_128(pg, t1);

    // base^100
    SQUARE_DEST_128(pg, t2, t1);

    // base^100100
    POW_ACC_DEST(pg, t3, 3, t2, t2);

    // base^100100100100
    POW_ACC_DEST(pg, t4, 6, t3, t3);

    // compute base^100100100100100100100100
    POW_ACC_DEST(pg, t5, 12, t4, t4);

    // compute base^100100100100100100100100100100
    POW_ACC_DEST(pg, t6, 6, t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    POW_ACC_DEST(pg, t7, 31, t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    SQUARE_128(pg, t7);
    MULTIPLY_128(pg, t7, t6);
    SQUARE_128(pg, t7);
    SQUARE_128(pg, t7);
    MULTIPLY_128(pg, t7, t1);
    MULTIPLY_128(pg, t7, t2);
    mul_128(pg, state1, &t7_1, state2, &t7_2, state3, &t7_3, state4, &t7_4, state_tail, t7_tail);
}
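The binary exponent in the final comment is `d = 0x92492491B6DB6DB7`. Since `7 * d ≡ 1 (mod p - 1)`, raising to the power `d` undoes the `x^7` S-box for every nonzero field element (Fermat's little theorem); that is what makes this the inverse S-box. A quick check of the identity:

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks prime
// The exponent computed by the addition chain above; its binary
// expansion matches the final comment of `apply_inv_sbox_128`.
const D: u64 = 0x9249_2491_B6DB_6DB7;

fn main() {
    // 7 * d ≡ 1 (mod p - 1), hence (x^7)^d = x for all nonzero x.
    assert_eq!((7u128 * D as u128) % ((P - 1) as u128), 1);
}
```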
bool add_constants_and_apply_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = 2;  // number of u64 numbers in one 128 bit SVE vector
    svbool_t ptrue = svptrue_b64();

    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
    svuint64_t state2 = svld1(ptrue, state + 1 * vl);
    svuint64_t state3 = svld1(ptrue, state + 2 * vl);
    svuint64_t state4 = svld1(ptrue, state + 3 * vl);

    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
    svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
    svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

    add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
    apply_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

    svst1(ptrue, state + 0 * vl, state1);
    svst1(ptrue, state + 1 * vl, state2);
    svst1(ptrue, state + 2 * vl, state3);
    svst1(ptrue, state + 3 * vl, state4);

    return true;
}

bool add_constants_and_apply_inv_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = 2;  // number of u64 numbers in one 128 bit SVE vector
    svbool_t ptrue = svptrue_b64();

    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
    svuint64_t state2 = svld1(ptrue, state + 1 * vl);
    svuint64_t state3 = svld1(ptrue, state + 2 * vl);
    svuint64_t state4 = svld1(ptrue, state + 3 * vl);

    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
    svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
    svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

    add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
    apply_inv_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

    svst1(ptrue, state + 0 * vl, state1);
    svst1(ptrue, state + 1 * vl, state2);
    svst1(ptrue, state + 2 * vl, state3);
    svst1(ptrue, state + 3 * vl, state4);

    return true;
}

#endif //RPO_SVE_RPO_HASH_128_H
261
arch/arm64-sve/rpo/rpo_hash_256bit.h
Normal file
261
arch/arm64-sve/rpo/rpo_hash_256bit.h
Normal file
|
@ -0,0 +1,261 @@
|
|||
#ifndef RPO_SVE_RPO_HASH_256_H
|
||||
#define RPO_SVE_RPO_HASH_256_H
|
||||
|
||||
#include <arm_sve.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
|
||||
#define STATE_WIDTH 12
|
||||
|
||||
#define COPY_256(NAME, VIN1, VIN2, SIN3) \
|
||||
svuint64_t NAME ## _1 = VIN1; \
|
||||
svuint64_t NAME ## _2 = VIN2; \
|
||||
uint64_t NAME ## _3[4]; \
|
||||
memcpy(NAME ## _3, SIN3, 4 * sizeof(uint64_t))
|
||||
|
||||
#define MULTIPLY_256(PRED, DEST, OP) \
|
||||
mul_256(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, DEST ## _3, OP ## _3)
|
||||
|
||||
#define SQUARE_256(PRED, NAME) \
|
||||
sq_256(PRED, &NAME ## _1, &NAME ## _2, NAME ## _3)
|
||||
|
||||
#define SQUARE_DEST_256(PRED, DEST, SRC) \
|
||||
COPY_256(DEST, SRC ## _1, SRC ## _2, SRC ## _3); \
|
||||
SQUARE_256(PRED, DEST);
|
||||
|
||||
#define POW_ACC(PRED, NAME, CNT, TAIL) \
|
||||
for (size_t i = 0; i < CNT; i++) { \
|
||||
SQUARE_256(PRED, NAME); \
|
||||
} \
|
||||
MULTIPLY_256(PRED, NAME, TAIL);
|
||||
|
||||
#define POW_ACC_DEST_256(PRED, DEST, CNT, HEAD, TAIL) \
|
||||
COPY_256(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3); \
|
||||
POW_ACC(PRED, DEST, CNT, TAIL)
|
||||
|
||||
extern inline void add_constants_256(
|
||||
svbool_t pg,
|
||||
svuint64_t *state1,
|
||||
svuint64_t *const1,
|
||||
svuint64_t *state2,
|
||||
svuint64_t *const2,
|
||||
uint64_t *state3,
|
||||
uint64_t *const3
|
||||
) {
|
||||
uint64_t Ms = 0xFFFFFFFF00000001ull;
|
||||
    svuint64_t Mv = svindex_u64(Ms, 0);

    // Scalar lanes: add constants to the last four state elements. The sum
    // state + c (mod p) is computed as state - (p - c); on borrow, 2^32 - 1 is
    // subtracted, which is equivalent to adding p modulo 2^64.
    uint64_t p_1 = Ms - const3[0];
    uint64_t p_2 = Ms - const3[1];
    uint64_t p_3 = Ms - const3[2];
    uint64_t p_4 = Ms - const3[3];

    uint64_t x_1, x_2, x_3, x_4;
    uint32_t adj_1 = -__builtin_sub_overflow(state3[0], p_1, &x_1);
    uint32_t adj_2 = -__builtin_sub_overflow(state3[1], p_2, &x_2);
    uint32_t adj_3 = -__builtin_sub_overflow(state3[2], p_3, &x_3);
    uint32_t adj_4 = -__builtin_sub_overflow(state3[3], p_4, &x_4);

    state3[0] = x_1 - (uint64_t)adj_1;
    state3[1] = x_2 - (uint64_t)adj_2;
    state3[2] = x_3 - (uint64_t)adj_3;
    state3[3] = x_4 - (uint64_t)adj_4;

    // Vector lanes: the same add-via-subtract trick on the two SVE registers.
    svuint64_t p1 = svsub_x(pg, Mv, *const1);
    svuint64_t p2 = svsub_x(pg, Mv, *const2);

    svuint64_t x1 = svsub_x(pg, *state1, p1);
    svuint64_t x2 = svsub_x(pg, *state2, p2);

    svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
    svbool_t pt2 = svcmplt_u64(pg, *state2, p2);

    *state1 = svsub_m(pt1, x1, (uint32_t)-1);
    *state2 = svsub_m(pt2, x2, (uint32_t)-1);
}

extern inline void mul_256(
    svbool_t pg,
    svuint64_t *r1,
    const svuint64_t *op1,
    svuint64_t *r2,
    const svuint64_t *op2,
    uint64_t *r3,
    const uint64_t *op3
) {
    // Scalar lanes: full 128-bit products of the last four state elements.
    __uint128_t x_1 = r3[0];
    __uint128_t x_2 = r3[1];
    __uint128_t x_3 = r3[2];
    __uint128_t x_4 = r3[3];

    x_1 *= (__uint128_t) op3[0];
    x_2 *= (__uint128_t) op3[1];
    x_3 *= (__uint128_t) op3[2];
    x_4 *= (__uint128_t) op3[3];

    uint64_t x0_1 = x_1;
    uint64_t x0_2 = x_2;
    uint64_t x0_3 = x_3;
    uint64_t x0_4 = x_4;

    // Vector lanes: low halves of the products.
    svuint64_t l1 = svmul_x(pg, *r1, *op1);
    svuint64_t l2 = svmul_x(pg, *r2, *op2);

    uint64_t x1_1 = (x_1 >> 64);
    uint64_t x1_2 = (x_2 >> 64);
    uint64_t x1_3 = (x_3 >> 64);
    uint64_t x1_4 = (x_4 >> 64);

    uint64_t a_1, a_2, a_3, a_4;
    uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
    uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
    uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
    uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);

    svuint64_t ls1 = svlsl_x(pg, l1, 32);
    svuint64_t ls2 = svlsl_x(pg, l2, 32);

    svuint64_t a1 = svadd_x(pg, l1, ls1);
    svuint64_t a2 = svadd_x(pg, l2, ls2);

    svbool_t e1 = svcmplt(pg, a1, l1);
    svbool_t e2 = svcmplt(pg, a2, l2);

    svuint64_t as1 = svlsr_x(pg, a1, 32);
    svuint64_t as2 = svlsr_x(pg, a2, 32);

    svuint64_t b1 = svsub_x(pg, a1, as1);
    svuint64_t b2 = svsub_x(pg, a2, as2);

    b1 = svsub_m(e1, b1, 1);
    b2 = svsub_m(e2, b2, 1);

    uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
    uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
    uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
    uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;

    uint64_t r_1, r_2, r_3, r_4;
    uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
    uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
    uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
    uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);

    // Vector lanes: high halves of the products.
    svuint64_t h1 = svmulh_x(pg, *r1, *op1);
    svuint64_t h2 = svmulh_x(pg, *r2, *op2);

    svuint64_t tr1 = svsub_x(pg, h1, b1);
    svuint64_t tr2 = svsub_x(pg, h2, b2);

    svbool_t c1 = svcmplt_u64(pg, h1, b1);
    svbool_t c2 = svcmplt_u64(pg, h2, b2);

    *r1 = svsub_m(c1, tr1, (uint32_t) -1);
    *r2 = svsub_m(c2, tr2, (uint32_t) -1);

    uint32_t minus1_1 = 0 - c_1;
    uint32_t minus1_2 = 0 - c_2;
    uint32_t minus1_3 = 0 - c_3;
    uint32_t minus1_4 = 0 - c_4;

    r3[0] = r_1 - (uint64_t)minus1_1;
    r3[1] = r_2 - (uint64_t)minus1_2;
    r3[2] = r_3 - (uint64_t)minus1_3;
    r3[3] = r_4 - (uint64_t)minus1_4;
}
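
/*
 * Note: mul_256 performs a Montgomery-style reduction for the Goldilocks prime
 * p = 2^64 - 2^32 + 1: for a 128-bit product x it returns x * 2^-64 mod p
 * (field elements are kept in Montgomery form by the Rust caller). Below is a
 * minimal scalar sketch of the reduction, mirroring the x0/a/b/r path above;
 * it is illustrative only and is not called by the vectorized code.
 */
static inline uint64_t mont_red_sketch(__uint128_t x) {
    uint64_t x0 = (uint64_t)x;          // low 64 bits of the product
    uint64_t x1 = (uint64_t)(x >> 64);  // high 64 bits of the product
    uint64_t a, r;
    uint64_t e = __builtin_add_overflow(x0, x0 << 32, &a);
    uint64_t b = a - (a >> 32) - e;
    uint32_t c = __builtin_sub_overflow(x1, b, &r);
    // on borrow, subtract 2^32 - 1, i.e. add p modulo 2^64
    return r - (uint64_t)(uint32_t)(0 - c);
}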

extern inline void sq_256(svbool_t pg, svuint64_t *a, svuint64_t *b, uint64_t *c) {
    mul_256(pg, a, a, b, b, c, c);
}

extern inline void apply_sbox_256(
    svbool_t pg,
    svuint64_t *state1,
    svuint64_t *state2,
    uint64_t *state3
) {
    COPY_256(x, *state1, *state2, state3); // copy input to x
    SQUARE_256(pg, x); // x contains input^2
    mul_256(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^3
    SQUARE_256(pg, x); // x contains input^4
    mul_256(pg, state1, &x_1, state2, &x_2, state3, x_3); // state contains input^7
}

extern inline void apply_inv_sbox_256(
    svbool_t pg,
    svuint64_t *state_1,
    svuint64_t *state_2,
    uint64_t *state_3
) {
    // Raises the state to the power 7^-1 mod (p - 1), i.e. inverts the x^7
    // S-box, via an addition chain; the comments track the exponent in binary.

    // base^10
    COPY_256(t1, *state_1, *state_2, state_3);
    SQUARE_256(pg, t1);

    // base^100
    SQUARE_DEST_256(pg, t2, t1);

    // base^100100
    POW_ACC_DEST_256(pg, t3, 3, t2, t2);

    // base^100100100100
    POW_ACC_DEST_256(pg, t4, 6, t3, t3);

    // compute base^100100100100100100100100
    POW_ACC_DEST_256(pg, t5, 12, t4, t4);

    // compute base^100100100100100100100100100100
    POW_ACC_DEST_256(pg, t6, 6, t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    POW_ACC_DEST_256(pg, t7, 31, t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    SQUARE_256(pg, t7);
    MULTIPLY_256(pg, t7, t6);
    SQUARE_256(pg, t7);
    SQUARE_256(pg, t7);
    MULTIPLY_256(pg, t7, t1);
    MULTIPLY_256(pg, t7, t2);
    mul_256(pg, state_1, &t7_1, state_2, &t7_2, state_3, t7_3);
}

bool add_constants_and_apply_sbox_256(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = 4; // number of u64 elements in one 256-bit SVE vector
    svbool_t ptrue = svptrue_b64();

    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
    svuint64_t state2 = svld1(ptrue, state + 1 * vl);

    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);

    add_constants_256(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
    apply_sbox_256(ptrue, &state1, &state2, state + 8);

    svst1(ptrue, state + 0 * vl, state1);
    svst1(ptrue, state + 1 * vl, state2);

    return true;
}

bool add_constants_and_apply_inv_sbox_256(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
    const uint64_t vl = 4; // number of u64 elements in one 256-bit SVE vector
    svbool_t ptrue = svptrue_b64();

    svuint64_t state1 = svld1(ptrue, state + 0 * vl);
    svuint64_t state2 = svld1(ptrue, state + 1 * vl);

    svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
    svuint64_t const2 = svld1(ptrue, constants + 1 * vl);

    add_constants_256(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
    apply_inv_sbox_256(ptrue, &state1, &state2, state + 8);

    svst1(ptrue, state + 0 * vl, state1);
    svst1(ptrue, state + 1 * vl, state2);

    return true;
}

#endif //RPO_SVE_RPO_HASH_256_H
58
benches/README.md
Normal file
@ -0,0 +1,58 @@
# Miden VM Hash Functions
In the Miden VM, we make use of different hash functions. Some of these are "traditional" hash functions, like `BLAKE3`, which are optimized for out-of-STARK performance, while others are algebraic hash functions, like `Rescue Prime`, which are optimized for performance inside the STARK. In what follows, we benchmark several such hash functions and compare them against constructions used by other proving systems. More precisely, we benchmark:

* **BLAKE3** as specified [here](https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf) and implemented [here](https://github.com/BLAKE3-team/BLAKE3) (with a wrapper exposed via this crate).
* **SHA3** as specified [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/sha/mod.rs).
* **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
* **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
* **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
* **Rescue Prime Extended (RPX)**, a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function, as implemented in this crate.

## Comparison and Instructions

### Comparison
We benchmark the above hash functions using two scenarios. The first is 2-to-1 hashing $(a,b)\mapsto h(a,b)$, where $a$, $b$, and $h(a,b)$ are digests of the hash function under test.
The second scenario is sequential hashing, where we take a sequence of $100$ field elements and hash it into a single digest. The digests are $4$ field elements in a prime field with modulus $2^{64} - 2^{32} + 1$ (i.e., 32 bytes) for Poseidon, Rescue Prime, and RPO, and an array `[u8; 32]` for SHA3 and BLAKE3.

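Concretely, the two scenarios correspond to the following calls into this crate's API (a sketch based on `benches/hash.rs` below; RPO is shown, the other hash functions are analogous):

```
use miden_crypto::{hash::rpo::Rpo256, Felt};

// Scenario 1: 2-to-1 hashing of two digests.
let digests = [Rpo256::hash(&[1_u8]), Rpo256::hash(&[2_u8])];
let merged = Rpo256::merge(&digests);

// Scenario 2: sequential hashing of 100 field elements into one digest.
let elements: Vec<Felt> = (0..100).map(Felt::new).collect();
let digest = Rpo256::hash_elements(&elements);
```
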
#### Scenario 1: 2-to-1 hashing `h(a,b)`

| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro        | 76 ns  | 245 ns  | 1.5 µs    | 9.1 µs    | 5.2 µs  | 2.7 µs  |
| Apple M2 Max        | 71 ns  | 233 ns  | 1.3 µs    | 7.9 µs    | 4.6 µs  | 2.4 µs  |
| Amazon Graviton 3   | 108 ns |         |           |           | 5.3 µs  | 3.1 µs  |
| Amazon Graviton 4   | 96 ns  |         |           |           | 5.1 µs  | 2.8 µs  |
| AMD Ryzen 9 5950X   | 64 ns  | 273 ns  | 1.2 µs    | 9.1 µs    | 5.5 µs  |         |
| AMD EPYC 9R14       | 83 ns  |         |           |           | 4.3 µs  | 2.4 µs  |
| Intel Core i5-8279U | 68 ns  | 536 ns  | 2.0 µs    | 13.6 µs   | 8.5 µs  | 4.4 µs  |
| Intel Xeon 8375C    | 67 ns  |         |           |           | 8.2 µs  |         |

#### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`

| Function            | BLAKE3 | SHA3    | Poseidon  | Rp64_256  | RPO_256 | RPX_256 |
| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro        | 1.0 µs | 1.5 µs  | 19.4 µs   | 118 µs    | 69 µs   | 35 µs   |
| Apple M2 Max        | 0.9 µs | 1.5 µs  | 17.4 µs   | 103 µs    | 60 µs   | 31 µs   |
| Amazon Graviton 3   | 1.4 µs |         |           |           | 69 µs   | 41 µs   |
| Amazon Graviton 4   | 1.2 µs |         |           |           | 67 µs   | 36 µs   |
| AMD Ryzen 9 5950X   | 0.8 µs | 1.7 µs  | 15.7 µs   | 120 µs    | 72 µs   |         |
| AMD EPYC 9R14       | 0.9 µs |         |           |           | 56 µs   | 32 µs   |
| Intel Core i5-8279U | 0.9 µs |         |           |           | 107 µs  | 56 µs   |
| Intel Xeon 8375C    | 0.8 µs |         |           |           | 110 µs  |         |

Notes:
- On Graviton 3 and 4, RPO_256 and RPX_256 are run with SVE acceleration enabled.
- On AMD EPYC 9R14, RPO_256 and RPX_256 are run with AVX2 acceleration enabled.

### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO, RPX, and BLAKE3, clone the current repository, and from the root directory of the repo run the following:

```
cargo bench hash
```

To run the benchmarks for Rescue Prime, Poseidon, and SHA3, clone the following [repository](https://github.com/Dominik1999/winterfell.git), check out the `hash-functions-benches` branch, and from the root directory of the repo run the following:

```
cargo bench hash
```
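
The Graviton numbers above were collected with the SVE backend enabled. One way to do this (an assumption about your toolchain setup; the exact invocation may vary) is to enable the `sve` target feature so that `build.rs` compiles the SVE C kernels:

```
RUSTFLAGS="-C target-feature=+sve" cargo bench hash
```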
161
benches/hash.rs
Normal file
@ -0,0 +1,161 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{
    hash::{
        blake::Blake3_256,
        rpo::{Rpo256, RpoDigest},
        rpx::{Rpx256, RpxDigest},
    },
    Felt,
};
use rand_utils::rand_value;
use winter_crypto::Hasher;

fn rpo256_2to1(c: &mut Criterion) {
    let v: [RpoDigest; 2] = [Rpo256::hash(&[1_u8]), Rpo256::hash(&[2_u8])];
    c.bench_function("RPO256 2-to-1 hashing (cached)", |bench| {
        bench.iter(|| Rpo256::merge(black_box(&v)))
    });

    c.bench_function("RPO256 2-to-1 hashing (random)", |bench| {
        bench.iter_batched(
            || {
                [
                    Rpo256::hash(&rand_value::<u64>().to_le_bytes()),
                    Rpo256::hash(&rand_value::<u64>().to_le_bytes()),
                ]
            },
            |state| Rpo256::merge(&state),
            BatchSize::SmallInput,
        )
    });
}

fn rpo256_sequential(c: &mut Criterion) {
    let v: [Felt; 100] = (0..100)
        .map(Felt::new)
        .collect::<Vec<Felt>>()
        .try_into()
        .expect("should not fail");
    c.bench_function("RPO256 sequential hashing (cached)", |bench| {
        bench.iter(|| Rpo256::hash_elements(black_box(&v)))
    });

    c.bench_function("RPO256 sequential hashing (random)", |bench| {
        bench.iter_batched(
            || {
                let v: [Felt; 100] = (0..100)
                    .map(|_| Felt::new(rand_value()))
                    .collect::<Vec<Felt>>()
                    .try_into()
                    .expect("should not fail");
                v
            },
            |state| Rpo256::hash_elements(&state),
            BatchSize::SmallInput,
        )
    });
}

fn rpx256_2to1(c: &mut Criterion) {
    let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
    c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
        bench.iter(|| Rpx256::merge(black_box(&v)))
    });

    c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
        bench.iter_batched(
            || {
                [
                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
                    Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
                ]
            },
            |state| Rpx256::merge(&state),
            BatchSize::SmallInput,
        )
    });
}

fn rpx256_sequential(c: &mut Criterion) {
    let v: [Felt; 100] = (0..100)
        .map(Felt::new)
        .collect::<Vec<Felt>>()
        .try_into()
        .expect("should not fail");
    c.bench_function("RPX256 sequential hashing (cached)", |bench| {
        bench.iter(|| Rpx256::hash_elements(black_box(&v)))
    });

    c.bench_function("RPX256 sequential hashing (random)", |bench| {
        bench.iter_batched(
            || {
                let v: [Felt; 100] = (0..100)
                    .map(|_| Felt::new(rand_value()))
                    .collect::<Vec<Felt>>()
                    .try_into()
                    .expect("should not fail");
                v
            },
            |state| Rpx256::hash_elements(&state),
            BatchSize::SmallInput,
        )
    });
}

fn blake3_2to1(c: &mut Criterion) {
    let v: [<Blake3_256 as Hasher>::Digest; 2] =
        [Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
    c.bench_function("Blake3 2-to-1 hashing (cached)", |bench| {
        bench.iter(|| Blake3_256::merge(black_box(&v)))
    });

    c.bench_function("Blake3 2-to-1 hashing (random)", |bench| {
        bench.iter_batched(
            || {
                [
                    Blake3_256::hash(&rand_value::<u64>().to_le_bytes()),
                    Blake3_256::hash(&rand_value::<u64>().to_le_bytes()),
                ]
            },
            |state| Blake3_256::merge(&state),
            BatchSize::SmallInput,
        )
    });
}

fn blake3_sequential(c: &mut Criterion) {
    let v: [Felt; 100] = (0..100)
        .map(Felt::new)
        .collect::<Vec<Felt>>()
        .try_into()
        .expect("should not fail");
    c.bench_function("Blake3 sequential hashing (cached)", |bench| {
        bench.iter(|| Blake3_256::hash_elements(black_box(&v)))
    });

    c.bench_function("Blake3 sequential hashing (random)", |bench| {
        bench.iter_batched(
            || {
                let v: [Felt; 100] = (0..100)
                    .map(|_| Felt::new(rand_value()))
                    .collect::<Vec<Felt>>()
                    .try_into()
                    .expect("should not fail");
                v
            },
            |state| Blake3_256::hash_elements(&state),
            BatchSize::SmallInput,
        )
    });
}

criterion_group!(
    hash_group,
    rpx256_2to1,
    rpx256_sequential,
    rpo256_2to1,
    rpo256_sequential,
    blake3_2to1,
    blake3_sequential
);
criterion_main!(hash_group);
66
benches/merkle.rs
Normal file
@ -0,0 +1,66 @@
//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared
//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with
//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building
//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle
//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this
//! benchmark is about four times more efficient than the equivalent benchmark in
//! `benches/smt-subtree.rs`.
use std::{hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE};
use rand_utils::prng_array;

fn balanced_merkle_even(c: &mut Criterion) {
    c.bench_function("balanced-merkle-even", |b| {
        b.iter_batched(
            || {
                let entries: Vec<Word> =
                    (0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
                assert_eq!(entries.len(), 256);
                entries
            },
            |leaves| {
                let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
                assert_eq!(tree.depth(), 8);
            },
            BatchSize::SmallInput,
        );
    });
}

fn balanced_merkle_rand(c: &mut Criterion) {
    let mut seed = [0u8; 32];
    c.bench_function("balanced-merkle-rand", |b| {
        b.iter_batched(
            || {
                let entries: Vec<Word> = (0..256).map(|_| generate_word(&mut seed)).collect();
                assert_eq!(entries.len(), 256);
                entries
            },
            |leaves| {
                let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
                assert_eq!(tree.depth(), 8);
            },
            BatchSize::SmallInput,
        );
    });
}

criterion_group! {
    name = smt_subtree_group;
    config = Criterion::default()
        .measurement_time(Duration::from_secs(20))
        .configure_from_args();
    targets = balanced_merkle_even, balanced_merkle_rand
}
criterion_main!(smt_subtree_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
143
benches/smt-subtree.rs
Normal file
@ -0,0 +1,143 @@
use std::{fmt::Debug, hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{build_subtree_for_bench, NodeIndex, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
    Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;

const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];

fn smt_subtree_even(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("subtree8-even");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
                        .map(|n| {
                            // A single depth-8 subtree can have a maximum of 255 leaves.
                            let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
                            let key = RpoDigest::new([
                                generate_value(&mut seed),
                                ONE,
                                Felt::new(n),
                                Felt::new(leaf_index),
                            ]);
                            let value = generate_word(&mut seed);
                            (key, value)
                        })
                        .collect();

                    let mut leaves: Vec<_> = entries
                        .iter()
                        .map(|(key, value)| {
                            let leaf = SmtLeaf::new_single(*key, *value);
                            let col = NodeIndex::from(leaf.index()).value();
                            let hash = leaf.hash();
                            SubtreeLeaf { col, hash }
                        })
                        .collect();
                    leaves.sort();
                    leaves.dedup_by_key(|leaf| leaf.col);
                    leaves
                },
                |leaves| {
                    // Benchmarked function.
                    let (subtree, _) = build_subtree_for_bench(
                        hint::black_box(leaves),
                        hint::black_box(SMT_DEPTH),
                        hint::black_box(SMT_DEPTH),
                    );
                    assert!(!subtree.is_empty());
                },
                BatchSize::SmallInput,
            );
        });
    }
}

fn smt_subtree_random(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("subtree8-rand");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
                        .map(|i| {
                            let leaf_index: u8 = generate_value(&mut seed);
                            let key = RpoDigest::new([
                                ONE,
                                ONE,
                                Felt::new(i),
                                Felt::new(leaf_index as u64),
                            ]);
                            let value = generate_word(&mut seed);
                            (key, value)
                        })
                        .collect();

                    let mut leaves: Vec<_> = entries
                        .iter()
                        .map(|(key, value)| {
                            let leaf = SmtLeaf::new_single(*key, *value);
                            let col = NodeIndex::from(leaf.index()).value();
                            let hash = leaf.hash();
                            SubtreeLeaf { col, hash }
                        })
                        .collect();
                    leaves.sort();
                    leaves.dedup_by_key(|leaf| leaf.col);
                    leaves
                },
                |leaves| {
                    let (subtree, _) = build_subtree_for_bench(
                        hint::black_box(leaves),
                        hint::black_box(SMT_DEPTH),
                        hint::black_box(SMT_DEPTH),
                    );
                    assert!(!subtree.is_empty());
                },
                BatchSize::SmallInput,
            );
        });
    }
}

criterion_group! {
    name = smt_subtree_group;
    config = Criterion::default()
        .measurement_time(Duration::from_secs(40))
        .sample_size(60)
        .configure_from_args();
    targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
    mem::swap(seed, &mut prng_array(*seed));
    let value: [T; 1] = rand_utils::prng_array(*seed);
    value[0]
}

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
71
benches/smt-with-entries.rs
Normal file
@ -0,0 +1,71 @@
use std::{fmt::Debug, hint, mem, time::Duration};

use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;

// 2^0, 2^4, 2^8, 2^12, 2^16, 2^20
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];

fn smt_with_entries(c: &mut Criterion) {
    let mut seed = [0u8; 32];

    let mut group = c.benchmark_group("smt-with-entries");

    for pair_count in PAIR_COUNTS {
        let bench_id = BenchmarkId::from_parameter(pair_count);
        group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
            b.iter_batched(
                || {
                    // Setup.
                    prepare_entries(pair_count, &mut seed)
                },
                |entries| {
                    // Benchmarked function.
                    Smt::with_entries(hint::black_box(entries)).unwrap();
                },
                BatchSize::SmallInput,
            );
        });
    }
}

criterion_group! {
    name = smt_with_entries_group;
    config = Criterion::default()
        //.measurement_time(Duration::from_secs(960))
        .measurement_time(Duration::from_secs(60))
        .sample_size(10)
        .configure_from_args();
    targets = smt_with_entries
}
criterion_main!(smt_with_entries_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn prepare_entries(pair_count: u64, seed: &mut [u8; 32]) -> Vec<(RpoDigest, [Felt; 4])> {
    let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
        .map(|i| {
            let count = pair_count as f64;
            let idx = ((i as f64 / count) * (count)) as u64;
            let key = RpoDigest::new([generate_value(seed), ONE, Felt::new(i), Felt::new(idx)]);
            let value = generate_word(seed);
            (key, value)
        })
        .collect();
    entries
}

fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
    mem::swap(seed, &mut prng_array(*seed));
    let value: [T; 1] = rand_utils::prng_array(*seed);
    value[0]
}

fn generate_word(seed: &mut [u8; 32]) -> Word {
    mem::swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
77
benches/smt.rs
Normal file
@ -0,0 +1,77 @@
use core::mem::swap;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use miden_crypto::{
    merkle::{LeafIndex, SimpleSmt},
    Felt, Word,
};
use rand_utils::prng_array;
use seq_macro::seq;

fn smt_rpo(c: &mut Criterion) {
    // setup trees

    let mut seed = [0u8; 32];
    let leaf = generate_word(&mut seed);

    // the seq! macro monomorphizes the benchmark body for each depth in 14..=20
    seq!(DEPTH in 14..=20 {
        let leaves = ((1 << DEPTH) - 1) as u64;
        for count in [1, leaves / 2, leaves] {
            let entries: Vec<_> = (0..count)
                .map(|i| {
                    let word = generate_word(&mut seed);
                    (i, word)
                })
                .collect();
            let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();

            // benchmark 1: updating a leaf
            let mut insert = c.benchmark_group("smt update_leaf".to_string());
            {
                let depth = DEPTH;
                let key = count >> 2;
                insert.bench_with_input(
                    format!("simple smt(depth:{depth},count:{count})"),
                    &(key, leaf),
                    |b, (key, leaf)| {
                        b.iter(|| {
                            tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
                        });
                    },
                );
            }
            insert.finish();

            // benchmark 2: computing a Merkle opening for a leaf
            let mut path = c.benchmark_group("smt get_leaf_path".to_string());
            {
                let depth = DEPTH;
                let key = count >> 2;
                path.bench_with_input(
                    format!("simple smt(depth:{depth},count:{count})"),
                    &key,
                    |b, key| {
                        b.iter(|| {
                            tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
                        });
                    },
                );
            }
            path.finish();
        }
    });
}

criterion_group!(smt_group, smt_rpo);
criterion_main!(smt_group);

// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------

fn generate_word(seed: &mut [u8; 32]) -> Word {
    swap(seed, &mut prng_array(*seed));
    let nums: [u64; 4] = prng_array(*seed);
    [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}
487
benches/store.rs
Normal file
@ -0,0 +1,487 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{
        DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt,
        SMT_MAX_DEPTH,
    },
    Felt, Word,
};
use rand_utils::{rand_array, rand_value};

/// Since a MerkleTree can only be created from a power-of-two number of elements, the sample
/// sizes are limited to powers of two.
static BATCH_SIZES: [usize; 3] = [2usize.pow(4), 2usize.pow(7), 2usize.pow(10)];

/// Generates a random `RpoDigest`.
fn random_rpo_digest() -> RpoDigest {
    rand_array::<Felt, 4>().into()
}

/// Generates a random `Word`.
fn random_word() -> Word {
    rand_array::<Felt, 4>()
}

/// Generates a random index at the specified depth in `0..range`.
fn random_index(range: u64, depth: u8) -> NodeIndex {
    let value = rand_value::<u64>() % range;
    NodeIndex::new(depth, value).unwrap()
}

/// Benchmarks getting an empty leaf from the SMT and MerkleStore backends.
fn get_empty_leaf_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_empty_leaf_simplesmt");

    const DEPTH: u8 = SMT_MAX_DEPTH;
    let size = u64::MAX;

    // Both the SMT and the store are pre-populated with empty hashes; accessing these values is
    // what is being benchmarked here, so no values are inserted into the backends.
    let smt = SimpleSmt::<DEPTH>::new().unwrap();
    let store = MerkleStore::from(&smt);
    let root = smt.root();

    group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
        b.iter_batched(
            || random_index(size, DEPTH),
            |index| black_box(smt.get_node(index)),
            BatchSize::SmallInput,
        )
    });

    group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
        b.iter_batched(
            || random_index(size, DEPTH),
            |index| black_box(store.get_node(root, index)),
            BatchSize::SmallInput,
        )
    });
}

/// Benchmarks getting a leaf on Merkle trees and Merkle stores of varying power-of-two sizes.
fn get_leaf_merkletree(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_leaf_merkletree");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
        let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
        let store = MerkleStore::from(&mtree);
        let depth = mtree.depth();
        let root = mtree.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
            b.iter_batched(
                || random_index(size_u64, depth),
                |index| black_box(mtree.get_node(index)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(size_u64, depth),
                |index| black_box(store.get_node(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks getting a leaf on SMT and Merkle stores of varying power-of-two sizes.
fn get_leaf_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_leaf_simplesmt");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let smt_leaves = leaves
            .iter()
            .enumerate()
            .map(|(c, v)| (c.try_into().unwrap(), v.into()))
            .collect::<Vec<(u64, Word)>>();
        let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
        let store = MerkleStore::from(&smt);
        let root = smt.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
            b.iter_batched(
                || random_index(size_u64, SMT_MAX_DEPTH),
                |index| black_box(smt.get_node(index)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(size_u64, SMT_MAX_DEPTH),
                |index| black_box(store.get_node(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks getting a node at half of the depth of an empty SMT and an empty Merkle store.
fn get_node_of_empty_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_node_of_empty_simplesmt");

    const DEPTH: u8 = SMT_MAX_DEPTH;

    // Both the SMT and the store are pre-populated with the empty hashes; accessing the internal
    // nodes of these values is what is being benchmarked here, so no values are inserted into
    // the backends.
    let smt = SimpleSmt::<DEPTH>::new().unwrap();
    let store = MerkleStore::from(&smt);
    let root = smt.root();
    let half_depth = DEPTH / 2;
    let half_size = 2_u64.pow(half_depth as u32);

    group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
        b.iter_batched(
            || random_index(half_size, half_depth),
            |index| black_box(smt.get_node(index)),
            BatchSize::SmallInput,
        )
    });

    group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
        b.iter_batched(
            || random_index(half_size, half_depth),
            |index| black_box(store.get_node(root, index)),
            BatchSize::SmallInput,
        )
    });
}

/// Benchmarks getting a node at half of the depth of a Merkle tree and Merkle store of varying
/// power-of-two sizes.
fn get_node_merkletree(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_node_merkletree");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
        let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
        let store = MerkleStore::from(&mtree);
        let root = mtree.root();
        let half_depth = mtree.depth() / 2;
        let half_size = 2_u64.pow(half_depth as u32);

        group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
            b.iter_batched(
                || random_index(half_size, half_depth),
                |index| black_box(mtree.get_node(index)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(half_size, half_depth),
                |index| black_box(store.get_node(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks getting a node at half the depth on SMT and Merkle stores of varying power-of-two
/// sizes.
fn get_node_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_node_simplesmt");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let smt_leaves = leaves
            .iter()
            .enumerate()
            .map(|(c, v)| (c.try_into().unwrap(), v.into()))
            .collect::<Vec<(u64, Word)>>();
        let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
        let store = MerkleStore::from(&smt);
        let root = smt.root();
        let half_depth = SMT_MAX_DEPTH / 2;
        let half_size = 2_u64.pow(half_depth as u32);

        group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
            b.iter_batched(
                || random_index(half_size, half_depth),
                |index| black_box(smt.get_node(index)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(half_size, half_depth),
                |index| black_box(store.get_node(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks getting a path of a leaf on the Merkle tree and Merkle store backends.
fn get_leaf_path_merkletree(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_leaf_path_merkletree");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
        let mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
        let store = MerkleStore::from(&mtree);
        let depth = mtree.depth();
        let root = mtree.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
            b.iter_batched(
                || random_index(size_u64, depth),
                |index| black_box(mtree.get_path(index)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(size_u64, depth),
                |index| black_box(store.get_path(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks getting a path of a leaf on the SMT and Merkle store backends.
fn get_leaf_path_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("get_leaf_path_simplesmt");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let smt_leaves = leaves
            .iter()
            .enumerate()
            .map(|(c, v)| (c.try_into().unwrap(), v.into()))
            .collect::<Vec<(u64, Word)>>();
        let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
        let store = MerkleStore::from(&smt);
        let root = smt.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
            b.iter_batched(
                || random_index(size_u64, SMT_MAX_DEPTH),
                |index| {
                    black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
                },
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || random_index(size_u64, SMT_MAX_DEPTH),
                |index| black_box(store.get_path(root, index)),
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks creation of the different storage backends.
fn new(c: &mut Criterion) {
    let mut group = c.benchmark_group("new");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        // The MerkleTree constructor is optimized to work with vectors. Create a new copy of the
        // data and pass it to the benchmark function.
        group.bench_function(BenchmarkId::new("MerkleTree::new", size), |b| {
            b.iter_batched(
                || leaves.iter().map(|v| v.into()).collect::<Vec<Word>>(),
                |l| black_box(MerkleTree::new(l)),
                BatchSize::SmallInput,
            )
        });

        // This could be done with `bench_with_input`; however, to remove variables while
        // comparing with MerkleTree, it uses `iter_batched`.
        group.bench_function(BenchmarkId::new("MerkleStore::extend::MerkleTree", size), |b| {
            b.iter_batched(
                || leaves.iter().map(|v| v.into()).collect::<Vec<Word>>(),
                |l| {
                    let mtree = MerkleTree::new(l).unwrap();
                    black_box(MerkleStore::from(&mtree));
                },
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("SimpleSmt::new", size), |b| {
            b.iter_batched(
                || {
                    leaves
                        .iter()
                        .enumerate()
                        .map(|(c, v)| (c.try_into().unwrap(), v.into()))
                        .collect::<Vec<(u64, Word)>>()
                },
                |l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
                BatchSize::SmallInput,
            )
        });

        group.bench_function(BenchmarkId::new("MerkleStore::extend::SimpleSmt", size), |b| {
            b.iter_batched(
                || {
                    leaves
                        .iter()
                        .enumerate()
                        .map(|(c, v)| (c.try_into().unwrap(), v.into()))
                        .collect::<Vec<(u64, Word)>>()
                },
                |l| {
                    let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
                    black_box(MerkleStore::from(&smt));
                },
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks updating a leaf on MerkleTree and MerkleStore backends.
fn update_leaf_merkletree(c: &mut Criterion) {
    let mut group = c.benchmark_group("update_leaf_merkletree");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let mtree_leaves: Vec<Word> = leaves.iter().map(|v| v.into()).collect();
        let mut mtree = MerkleTree::new(mtree_leaves.clone()).unwrap();
        let mut store = MerkleStore::from(&mtree);
        let depth = mtree.depth();
        let root = mtree.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("MerkleTree", size), |b| {
            b.iter_batched(
                || (rand_value::<u64>() % size_u64, random_word()),
                |(index, value)| black_box(mtree.update_leaf(index, value)),
                BatchSize::SmallInput,
            )
        });

        let mut store_root = root;
        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || (random_index(size_u64, depth), random_word()),
                |(index, value)| {
                    // The MerkleTree automatically updates its internal root; the store keeps
                    // the old root and adds the new one. Here we update the root to have a fair
                    // comparison.
                    store_root = store.set_node(root, index, value.into()).unwrap().root;
                    black_box(store_root)
                },
                BatchSize::SmallInput,
            )
        });
    }
}

/// Benchmarks updating a leaf on SMT and MerkleStore backends.
fn update_leaf_simplesmt(c: &mut Criterion) {
    let mut group = c.benchmark_group("update_leaf_simplesmt");

    let random_data_size = BATCH_SIZES.into_iter().max().unwrap();
    let random_data: Vec<RpoDigest> = (0..random_data_size).map(|_| random_rpo_digest()).collect();

    for size in BATCH_SIZES {
        let leaves = &random_data[..size];

        let smt_leaves = leaves
            .iter()
            .enumerate()
            .map(|(c, v)| (c.try_into().unwrap(), v.into()))
            .collect::<Vec<(u64, Word)>>();
        let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
        let mut store = MerkleStore::from(&smt);
        let root = smt.root();
        let size_u64 = size as u64;

        group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
            b.iter_batched(
                || (rand_value::<u64>() % size_u64, random_word()),
                |(index, value)| {
                    black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
                },
                BatchSize::SmallInput,
            )
        });

        let mut store_root = root;
        group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
            b.iter_batched(
                || (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
                |(index, value)| {
                    // The MerkleTree automatically updates its internal root; the store keeps
                    // the old root and adds the new one. Here we update the root to have a fair
                    // comparison.
                    store_root = store.set_node(root, index, value.into()).unwrap().root;
                    black_box(store_root)
                },
                BatchSize::SmallInput,
            )
        });
    }
}

criterion_group!(
    store_group,
    get_empty_leaf_simplesmt,
    get_leaf_merkletree,
    get_leaf_path_merkletree,
    get_leaf_path_simplesmt,
    get_leaf_simplesmt,
    get_node_merkletree,
    get_node_of_empty_simplesmt,
    get_node_simplesmt,
    new,
    update_leaf_merkletree,
    update_leaf_simplesmt,
);
criterion_main!(store_group);
19
build.rs
Normal file
@ -0,0 +1,19 @@
fn main() {
    #[cfg(target_feature = "sve")]
    compile_arch_arm64_sve();
}

#[cfg(target_feature = "sve")]
fn compile_arch_arm64_sve() {
    const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.c");
    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/library.h");
    println!("cargo:rerun-if-changed={RPO_SVE_PATH}/rpo_hash.h");

    cc::Build::new()
        .file(format!("{RPO_SVE_PATH}/library.c"))
        .flag("-march=armv8-a+sve")
        .flag("-O3")
        .compile("rpo_sve");
}
5
rust-toolchain.toml
Normal file
@ -0,0 +1,5 @@
[toolchain]
channel = "1.84"
components = ["rustfmt", "rust-src", "clippy"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"
23
rustfmt.toml
Normal file
@ -0,0 +1,23 @@
edition = "2021"
array_width = 80
attr_fn_like_width = 80
chain_width = 80
comment_width = 100
condense_wildcard_suffixes = true
fn_call_width = 80
format_code_in_doc_comments = true
format_macro_matchers = true
group_imports = "StdExternalCrate"
hex_literal_case = "Lower"
imports_granularity = "Crate"
match_block_trailing_comma = true
newline_style = "Unix"
reorder_imports = true
reorder_modules = true
single_line_if_else_max_width = 60
single_line_let_else_max_width = 60
struct_lit_width = 40
struct_variant_width = 40
use_field_init_shorthand = true
use_try_shorthand = true
wrap_comments = true
21
scripts/check-changelog.sh
Executable file
@ -0,0 +1,21 @@
#!/bin/bash
set -uo pipefail

CHANGELOG_FILE="${1:-CHANGELOG.md}"

if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then
    # the 'no changelog' label is set, so finish successfully
    echo "\"no changelog\" label has been set"
    exit 0
else
    # a changelog check is required; fail if the diff is empty
    if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
        >&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
        can be overridden by using the \"no changelog\" label, which is used for changes
        that are trivial / explicitly stated not to require a changelog entry."
        exit 1
    fi

    echo "The \"CHANGELOG.md\" file has been updated."
fi
15
scripts/check-rust-version.sh
Executable file
@ -0,0 +1,15 @@
#!/bin/bash

# Get the channel from the rust-toolchain.toml file
TOOLCHAIN_VERSION=$(grep 'channel' rust-toolchain.toml | sed -E 's/.*"(.*)".*/\1/')

# Get the rust-version from the workspace Cargo.toml file
CARGO_VERSION=$(grep 'rust-version' Cargo.toml | sed -E 's/.*"(.*)".*/\1/')

# Check that the versions match
if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
    echo "Mismatch in Cargo.toml: expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
    exit 1
fi

echo "Rust versions match ✅"
1
series
@ -1 +0,0 @@
Subproject commit fa0943fc4864a76c98516177e9f7d781d35a57e6
3
src/dsa/mod.rs
Normal file
@ -0,0 +1,3 @@
//! Digital signature schemes supported by default in the Miden VM.

pub mod rpo_falcon512;
70
src/dsa/rpo_falcon512/hash_to_point.rs
Normal file
@ -0,0 +1,70 @@
use alloc::vec::Vec;

use num::Zero;

use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};

// HASH-TO-POINT FUNCTIONS
// ================================================================================================

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using RPO256.
pub fn hash_to_point_rpo256(message: Word, nonce: &Nonce) -> Polynomial<FalconFelt> {
    let mut state = [ZERO; Rpo256::STATE_WIDTH];

    // absorb the nonce into the state
    let nonce_elements = nonce.to_elements();
    for (&n, s) in nonce_elements.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = n;
    }
    Rpo256::apply_permutation(&mut state);

    // absorb message into the state
    for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
        *s = m;
    }

    // squeeze the coefficients of the polynomial
    let mut i = 0;
    let mut res = [FalconFelt::zero(); N];
    for _ in 0..64 {
        Rpo256::apply_permutation(&mut state);
        for a in &state[Rpo256::RATE_RANGE] {
            res[i] = FalconFelt::new((a.as_int() % MODULUS as u64) as i16);
            i += 1;
        }
    }

    Polynomial::new(res.to_vec())
}
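
// Note: `Rpo256::RATE_RANGE` spans 8 state elements, so the 64 permutations
// above squeeze exactly 64 * 8 = 512 = N coefficients. Unlike the SHAKE256
// variant below, each squeezed element is reduced mod MODULUS directly rather
// than via rejection sampling.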

/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using SHAKE256. This is the hash-to-point algorithm used in the reference implementation.
#[allow(dead_code)]
pub fn hash_to_point_shake256(message: &[u8], nonce: &Nonce) -> Polynomial<FalconFelt> {
    use sha3::{
        digest::{ExtendableOutput, Update, XofReader},
        Shake256,
    };

    let mut data = vec![];
    data.extend_from_slice(nonce.as_bytes());
    data.extend_from_slice(message);
    const K: u32 = (1u32 << 16) / MODULUS as u32;

    let mut hasher = Shake256::default();
    hasher.update(&data);
    let mut reader = hasher.finalize_xof();

    let mut coefficients: Vec<FalconFelt> = Vec::with_capacity(N);
    while coefficients.len() != N {
        let mut randomness = [0u8; 2];
        reader.read(&mut randomness);
        let t = ((randomness[0] as u32) << 8) | (randomness[1] as u32);
        // rejection sampling: accept only values below K * MODULUS to avoid modular bias
        if t < K * MODULUS as u32 {
            coefficients.push(FalconFelt::new((t % MODULUS as u32) as i16));
        }
    }

    Polynomial { coefficients }
}
55
src/dsa/rpo_falcon512/keys/mod.rs
Normal file
@ -0,0 +1,55 @@
use super::{
    math::{FalconFelt, Polynomial},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Serializable, Signature,
    Word,
};

mod public_key;
pub use public_key::{PubKeyPoly, PublicKey};

mod secret_key;
pub use secret_key::SecretKey;

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use winter_math::FieldElement;
    use winter_utils::{Deserializable, Serializable};

    use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};

    #[test]
    fn test_falcon_verification() {
        let seed = [0_u8; 32];
        let mut rng = ChaCha20Rng::from_seed(seed);

        // generate random keys
        let sk = SecretKey::with_rng(&mut rng);
        let pk = sk.public_key();

        // test secret key serialization/deserialization
        let mut buffer = vec![];
        sk.write_into(&mut buffer);
        let sk_deserialized = SecretKey::read_from_bytes(&buffer).unwrap();
        assert_eq!(sk.short_lattice_basis(), sk_deserialized.short_lattice_basis());

        // sign a random message
        let message: Word = [ONE; 4];
        let signature = sk.sign_with_rng(message, &mut rng);

        // make sure the signature verifies correctly
        assert!(pk.verify(message, &signature));

        // a signature should not verify against a wrong message
        let message2: Word = [ONE.double(); 4];
        assert!(!pk.verify(message2, &signature));

        // a signature should not verify against a wrong public key
        let sk2 = SecretKey::with_rng(&mut rng);
        assert!(!sk2.public_key().verify(message, &signature))
    }
}
139
src/dsa/rpo_falcon512/keys/public_key.rs
Normal file
@ -0,0 +1,139 @@
use alloc::string::ToString;
use core::ops::Deref;

use num::Zero;

use super::{
    super::{Rpo256, LOG_N, N, PK_LEN},
    ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
    Serializable, Signature, Word,
};
use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;

// PUBLIC KEY
// ================================================================================================

/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key. The hash is computed
/// using Rpo256.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);

impl PublicKey {
    /// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
    pub fn new(pub_key: Word) -> Self {
        Self(pub_key)
    }

    /// Verifies the provided signature against the provided message and this public key.
    pub fn verify(&self, message: Word, signature: &Signature) -> bool {
        signature.verify(message, self.0)
    }
}

impl From<PubKeyPoly> for PublicKey {
    fn from(pk_poly: PubKeyPoly) -> Self {
        let pk_felts: Polynomial<Felt> = pk_poly.0.into();
        let pk_digest = Rpo256::hash_elements(&pk_felts.coefficients).into();
        Self(pk_digest)
    }
}

impl From<PublicKey> for Word {
    fn from(key: PublicKey) -> Self {
        key.0
    }
}

// PUBLIC KEY POLYNOMIAL
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubKeyPoly(pub Polynomial<FalconFelt>);

impl Deref for PubKeyPoly {
    type Target = Polynomial<FalconFelt>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Polynomial<FalconFelt>> for PubKeyPoly {
    fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
        Self(pk_poly)
    }
}

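// Encoding note: the expanded public key packs each of the N = 512
// coefficients into FALCON_ENCODING_BITS (14) bits -- see the `0x3fff` mask in
// `read_from` below -- after a one-byte header carrying LOG_N, for a total of
// PK_LEN = 1 + 512 * 14 / 8 = 897 bytes.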
impl Serializable for &PubKeyPoly {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let mut buf = [0_u8; PK_LEN];
        buf[0] = LOG_N;

        let mut acc = 0_u32;
        let mut acc_len: u32 = 0;

        let mut input_pos = 1;
        for c in self.0.coefficients.iter() {
            let c = c.value();
            acc = (acc << FALCON_ENCODING_BITS) | c as u32;
            acc_len += FALCON_ENCODING_BITS;
            while acc_len >= 8 {
                acc_len -= 8;
                buf[input_pos] = (acc >> acc_len) as u8;
                input_pos += 1;
            }
        }
        if acc_len > 0 {
            buf[input_pos] = (acc >> (8 - acc_len)) as u8;
        }

        target.write(buf);
    }
}

impl Deserializable for PubKeyPoly {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let buf = source.read_array::<PK_LEN>()?;

        if buf[0] != LOG_N {
            return Err(DeserializationError::InvalidValue(format!(
                "Failed to decode public key: expected the first byte to be {LOG_N} but was {}",
                buf[0]
            )));
        }

        let mut acc = 0_u32;
        let mut acc_len = 0;

        let mut output = [FalconFelt::zero(); N];
        let mut output_idx = 0;

        for &byte in buf.iter().skip(1) {
            acc = (acc << 8) | (byte as u32);
            acc_len += 8;

            if acc_len >= FALCON_ENCODING_BITS {
                acc_len -= FALCON_ENCODING_BITS;
                let w = (acc >> acc_len) & 0x3fff;
                let element = w.try_into().map_err(|err| {
                    DeserializationError::InvalidValue(format!(
                        "Failed to decode public key: {err}"
                    ))
                })?;
                output[output_idx] = element;
                output_idx += 1;
            }
        }

        if (acc & ((1u32 << acc_len) - 1)) == 0 {
            Ok(Polynomial::new(output.to_vec()).into())
        } else {
            Err(DeserializationError::InvalidValue(
                "Failed to decode public key: input not fully consumed".to_string(),
            ))
        }
    }
}
401
src/dsa/rpo_falcon512/keys/secret_key.rs
Normal file
401
src/dsa/rpo_falcon512/keys/secret_key.rs
Normal file
|
@ -0,0 +1,401 @@
|
|||
use alloc::{string::ToString, vec::Vec};

use num::Complex;
#[cfg(not(feature = "std"))]
use num::Float;
use num_complex::Complex64;
use rand::Rng;

use super::{
    super::{
        math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},
        signature::SignaturePoly,
        ByteReader, ByteWriter, Deserializable, DeserializationError, Nonce, Serializable,
        ShortLatticeBasis, Signature, Word, MODULUS, N, SIGMA, SIG_L2_BOUND,
    },
    PubKeyPoly, PublicKey,
};
use crate::dsa::rpo_falcon512::{
    hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
};

// CONSTANTS
// ================================================================================================

const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;

// SECRET KEY
// ================================================================================================

/// Represents the secret key for Falcon DSA.
///
/// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
/// polynomial is of degree at most N = 512, and computations with these polynomials are done
/// modulo the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice
/// and has the property of being short with respect to a certain norm and an upper bound
/// appropriate for a given security parameter. The public key, on the other hand, is another
/// basis for the same lattice and can be described by a single polynomial h with integer
/// coefficients modulo ϕ. The two keys are related by the following relations:
///
/// 1. h = g / f [mod ϕ][mod p]
/// 2. f.G - g.F = p [mod ϕ]
///
/// where p = 12289 is the Falcon prime. Equation 2 is called the NTRU equation.
/// The secret key is generated by first sampling a random pair (f, g) of polynomials using
/// an appropriate distribution that yields short but not too short polynomials with integer
/// coefficients modulo ϕ. The NTRU equation is then used to find a matching pair (F, G).
/// The public key is then derived from the secret key using equation 1.
///
/// To allow for fast signature generation, the secret key is pre-processed into a more suitable
/// form, called the LDL tree, and this allows for fast sampling of short vectors in the lattice
/// using Fast Fourier sampling during signature generation (ffSampling, algorithm 11 in [1]).
///
/// [1]: https://falcon-sign.info/falcon.pdf
#[derive(Debug, Clone)]
pub struct SecretKey {
    secret_key: ShortLatticeBasis,
    tree: LdlTree,
}

#[allow(clippy::new_without_default)]
impl SecretKey {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Generates a secret key from OS-provided randomness.
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        Self::with_rng(&mut rng)
    }

    /// Generates a secret key using the provided random number generator `Rng`.
    pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
        let basis = ntru_gen(N, rng);
        Self::from_short_lattice_basis(basis)
    }

    /// Given a short basis [[g, -f], [G, -F]], computes the normalized LDL tree, i.e., the
    /// Falcon tree.
    fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
        // FFT each polynomial of the short basis.
        let basis_fft = to_complex_fft(&basis);
        // compute the Gram matrix.
        let gram_fft = gram(basis_fft);
        // construct the LDL tree of the Gram matrix.
        let mut tree = ffldl(gram_fft);
        // normalize the leaves of the LDL tree.
        normalize_tree(&mut tree, SIGMA);
        Self { secret_key: basis, tree }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the polynomials of the short lattice basis of this secret key.
    pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
        &self.secret_key
    }

    /// Returns the public key corresponding to this secret key.
    pub fn public_key(&self) -> PublicKey {
        self.compute_pub_key_poly().into()
    }

    /// Returns the LDL tree associated to this secret key.
    pub fn tree(&self) -> &LdlTree {
        &self.tree
    }

    // SIGNATURE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Signs a message with this secret key.
    #[cfg(feature = "std")]
    pub fn sign(&self, message: Word) -> Signature {
        use rand::{rngs::StdRng, SeedableRng};

        let mut rng = StdRng::from_entropy();
        self.sign_with_rng(message, &mut rng)
    }

    /// Signs a message with the secret key relying on the provided randomness generator.
    pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
        let mut nonce_bytes = [0u8; SIG_NONCE_LEN];
        rng.fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::new(nonce_bytes);

        let h = self.compute_pub_key_poly();
        let c = hash_to_point_rpo256(message, &nonce);
        let s2 = self.sign_helper(c, rng);

        Signature::new(nonce, h, s2)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Derives the public key corresponding to this secret key using h = g / f [mod ϕ][mod p].
    pub fn compute_pub_key_poly(&self) -> PubKeyPoly {
        let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
        let g_fft = g.fft();
        let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
        let f = -minus_f;
        let f_fft = f.fft();
        let h_fft = g_fft.hadamard_div(&f_fft);
        h_fft.ifft().into()
    }

    /// Signs a message polynomial with the secret key.
    ///
    /// Takes a randomness generator implementing `Rng` and a message polynomial `c` representing
    /// the hash-to-point of the message to be signed. It outputs a signature polynomial `s2`.
    fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
        let one_over_q = 1.0 / (MODULUS as f64);
        let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();

        // B = [[FFT(g), -FFT(f)], [FFT(G), -FFT(F)]]
        let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
        let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
        let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);

        loop {
            let bold_s = loop {
                let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
                let t0_min_z0 = t0.clone() - z.0;
                let t1_min_z1 = t1.clone() - z.1;

                // s = (t - z) * B
                let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
                let s1 =
                    t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);

                // compute the norm of (s0||s1); the polynomials are in FFT representation,
                // so by Parseval's identity we divide the sum of squared magnitudes by N
                let length_squared: f64 =
                    (s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
                        + s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
                        / (N as f64);

                if length_squared > (SIG_L2_BOUND as f64) {
                    continue;
                }

                break [-s0, s1];
            };

            let s2 = bold_s[1].ifft();
            let s2_coef: [i16; N] = s2
                .coefficients
                .iter()
                .map(|a| a.re.round() as i16)
                .collect::<Vec<i16>>()
                .try_into()
                .expect("the number of coefficients should be equal to N");

            if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
                return s2;
            }
        }
    }
}
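The rejection loop in sign_helper above keeps a candidate only when the squared L2 norm of (s0||s1) is at most SIG_L2_BOUND. Below is a standalone, std-only sketch of the same acceptance test, written in the coefficient domain for clarity; the concrete bound 34034726 is the Falcon-512 value from the specification and is an assumption here, not read from this crate.

// Sketch of the signature-norm acceptance test; the bound below is assumed.
fn main() {
    const SIG_L2_BOUND: u64 = 34034726; // Falcon-512 spec value (assumption)

    // Stand-in coefficient vectors for s1 and s2; real candidates come from ffsampling.
    let s1: Vec<i64> = vec![3; 512];
    let s2: Vec<i64> = vec![-4; 512];

    let norm_squared: u64 = s1.iter().chain(s2.iter()).map(|&c| (c * c) as u64).sum();

    // Accept only sufficiently short vectors; 512 * (9 + 16) = 12800 passes easily.
    assert!(norm_squared <= SIG_L2_BOUND);
}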

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for SecretKey {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let basis = &self.secret_key;

        // header
        let n = basis[0].coefficients.len();
        let l = n.checked_ilog2().unwrap() as u8;
        let header: u8 = (5 << 4) | l;

        let neg_f = &basis[1];
        let g = &basis[0];
        let neg_big_f = &basis[3];

        let mut buffer = Vec::with_capacity(1281);
        buffer.push(header);

        let f_i8: Vec<i8> = neg_f
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
            .collect();
        let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&f_i8_encoded);

        let g_i8: Vec<i8> = g
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(a).balanced_value() as i8)
            .collect();
        let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&g_i8_encoded);

        let big_f_i8: Vec<i8> = neg_big_f
            .coefficients
            .iter()
            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
            .collect();
        let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
        buffer.extend_from_slice(&big_f_i8_encoded);
        target.write_bytes(&buffer);
    }
}
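The serializer above prefixes the key with a single header byte whose high nibble is fixed to 5 and whose low nibble is log2(n). A tiny std-only sketch of that layout for Falcon-512:

// Header-byte layout sketch: 0101 | log2(n). For n = 512 this is 0x59.
fn main() {
    let n: usize = 512;
    let header: u8 = (5 << 4) | (n.ilog2() as u8);
    assert_eq!(header, 0x59);
    // The deserializer's checks invert this: fixed high nibble, then n = 2^(low nibble).
    assert_eq!(header >> 4, 5);
    assert_eq!(1usize << (header & 15), 512);
}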

impl Deserializable for SecretKey {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let byte_vector: [u8; SK_LEN] = source.read_array()?;

        // check length
        if byte_vector.len() < 2 {
            return Err(DeserializationError::InvalidValue(
                "Invalid encoding length: failed to decode as the length differs from the expected one".to_string(),
            ));
        }

        // read fields
        let header = byte_vector[0];

        // check fixed bits in header
        if (header >> 4) != 5 {
            return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
        }

        // check log n
        let logn = (header & 15) as usize;
        let n = 1 << logn;

        // match against the Falcon variant supported by this implementation (N = 512)
        if n != N {
            return Err(DeserializationError::InvalidValue(
                "Unsupported Falcon DSA variant".to_string(),
            ));
        }

        if byte_vector.len() != SK_LEN {
            return Err(DeserializationError::InvalidValue(
                "Invalid encoding length: failed to decode as the length differs from the expected one".to_string(),
            ));
        }

        let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
        let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;

        let f = decode_i8(&byte_vector[1..chunk_size_f + 1], WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
        let g = decode_i8(
            &byte_vector[chunk_size_f + 1..(chunk_size_f + chunk_size_g + 1)],
            WIDTH_SMALL_POLY_COEFFICIENT,
        )
        .unwrap();
        let big_f = decode_i8(
            &byte_vector[(chunk_size_f + chunk_size_g + 1)
                ..(chunk_size_f + chunk_size_g + chunk_size_big_f + 1)],
            WIDTH_BIG_POLY_COEFFICIENT,
        )
        .unwrap();

        let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
        let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());

        // big_g * f - g * big_f = p (mod X^n + 1); modulo p this reduces to
        // big_g = g * big_f / f, computed coefficient-wise in the FFT domain.
        let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();
        let basis = [
            g.map(|f| f.balanced_value()),
            -f.map(|f| f.balanced_value()),
            big_g.map(|f| f.balanced_value()),
            -big_f.map(|f| f.balanced_value()),
        ];
        Ok(Self::from_short_lattice_basis(basis))
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Computes the complex FFT of the secret key polynomials.
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
    let [g, f, big_g, big_f] = basis.clone();
    let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
    let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
    [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}

/// Encodes a sequence of signed integers such that each integer x satisfies |x| < 2^(bits - 1)
/// for a given parameter bits. bits can take either the value 6 or 8.
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
    let maxv = (1 << (bits - 1)) - 1_usize;
    let maxv = maxv as i8;
    let minv = -maxv;

    for &c in x {
        if c > maxv || c < minv {
            return None;
        }
    }

    let out_len = ((N * bits) + 7) >> 3;
    let mut buf = vec![0_u8; out_len];

    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = ((1_u16 << bits) - 1) as u8;

    let mut input_pos = 0;
    for &c in x {
        acc = (acc << bits) | (c as u8 & mask) as u32;
        acc_len += bits;
        while acc_len >= 8 {
            acc_len -= 8;
            buf[input_pos] = (acc >> acc_len) as u8;
            input_pos += 1;
        }
    }
    if acc_len > 0 {
        buf[input_pos] = (acc >> (8 - acc_len)) as u8;
    }

    Some(buf)
}

/// Decodes a sequence of bytes into a sequence of signed integers such that each integer x
/// satisfies |x| < 2^(bits - 1) for a given parameter bits. bits can take either the value 6 or 8.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
    let mut x = [0_i8; N];

    let mut i = 0;
    let mut j = 0;
    let mut acc = 0_u32;
    let mut acc_len = 0;
    let mask = (1_u32 << bits) - 1;
    let a = (1 << bits) as u8;
    let b = ((1 << (bits - 1)) - 1) as u8;

    while i < N {
        acc = (acc << 8) | (buf[j] as u32);
        j += 1;
        acc_len += 8;

        while acc_len >= bits && i < N {
            acc_len -= bits;
            let w = (acc >> acc_len) & mask;

            let w = w as u8;

            let z = if w > b { w as i8 - a as i8 } else { w as i8 };

            x[i] = z;
            i += 1;
        }
    }

    if (acc & ((1u32 << acc_len) - 1)) == 0 {
        Some(x.to_vec())
    } else {
        None
    }
}
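A hypothetical round trip through the encode_i8/decode_i8 helpers above (both are public in this file), using the 6-bit width WIDTH_SMALL_POLY_COEFFICIENT, which restricts coefficients to |x| <= 31:

// Round-trip sketch; assumes encode_i8 and decode_i8 from above are in scope.
fn main() {
    let mut coeffs = vec![0i8; 512]; // one entry per polynomial slot (N = 512)
    coeffs[0] = 31; // largest encodable value for bits = 6
    coeffs[1] = -31; // smallest encodable value
    coeffs[511] = -1;

    let bytes = encode_i8(&coeffs, 6).expect("all values fit in 6 bits");
    assert_eq!(bytes.len(), (512 * 6 + 7) >> 3); // 384 bytes, byte-aligned
    let decoded = decode_i8(&bytes, 6).expect("no stray trailing bits");
    assert_eq!(coeffs, decoded);
}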
124
src/dsa/rpo_falcon512/math/ffsampling.rs
Normal file
@ -0,0 +1,124 @@
use alloc::boxed::Box;

#[cfg(not(feature = "std"))]
use num::Float;
use num::{One, Zero};
use num_complex::{Complex, Complex64};
use rand::Rng;

use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};

const SIGMIN: f64 = 1.2778336969128337;

/// Computes the Gram matrix. The argument must be a 2x2 matrix
/// whose elements are equal-length vectors of complex numbers,
/// representing polynomials in FFT domain.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
    const N: usize = 2;
    let mut g: [Polynomial<Complex<f64>>; 4] =
        [Polynomial::zero(), Polynomial::zero(), Polynomial::zero(), Polynomial::zero()];
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                g[N * i + j] = g[N * i + j].clone()
                    + b[N * i + k].hadamard_mul(&b[N * j + k].map(|c| c.conj()));
            }
        }
    }
    g
}
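gram above computes G = B·B† entrywise in the FFT domain. A scalar sanity check (length-1 "polynomials") of the two properties the LDL step relies on, Hermitian symmetry and a real non-negative diagonal; assumes only the num-complex crate:

use num_complex::Complex64;

fn main() {
    // A 2x2 matrix B with complex scalar entries [b00, b01, b10, b11].
    let b = [
        Complex64::new(1.0, 1.0),
        Complex64::new(2.0, 0.0),
        Complex64::new(0.0, -1.0),
        Complex64::new(3.0, 2.0),
    ];
    // g_ij = sum_k b_ik * conj(b_jk), mirroring the triple loop above.
    let g00 = b[0] * b[0].conj() + b[1] * b[1].conj();
    let g01 = b[0] * b[2].conj() + b[1] * b[3].conj();
    let g10 = b[2] * b[0].conj() + b[3] * b[1].conj();

    assert!(g00.im.abs() < 1e-12 && g00.re >= 0.0); // real, non-negative diagonal
    assert!((g10 - g01.conj()).norm() < 1e-12); // Hermitian symmetry
}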

/// Computes the LDL decomposition of a 2x2 matrix G such that
///     L D L* = G
/// where D is diagonal and L is lower-triangular. The elements of the matrices are in FFT domain.
pub fn ldl(
    g: [Polynomial<Complex64>; 4],
) -> ([Polynomial<Complex64>; 4], [Polynomial<Complex64>; 4]) {
    let zero = Polynomial::<Complex64>::zero();
    let one = Polynomial::<Complex64>::one();

    let l10 = g[2].hadamard_div(&g[0]);
    let bc = l10.map(|c| c * c.conj());
    let abc = g[0].hadamard_mul(&bc);
    let d11 = g[3].clone() - abc;

    let l = [one.clone(), zero.clone(), l10.clone(), one];
    let d = [g[0].clone(), zero.clone(), zero, d11];
    (l, d)
}

#[derive(Debug, Clone)]
pub enum LdlTree {
    Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
    Leaf([Complex64; 2]),
}

/// Computes the LDL Tree of G. Corresponds to Algorithm 9 of the specification [1, p.37].
/// The argument is a 2x2 matrix of polynomials, given in FFT form.
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
    let n = gram_matrix[0].coefficients.len();
    let (l, d) = ldl(gram_matrix);

    if n > 2 {
        let (d00, d01) = d[0].split_fft();
        let (d10, d11) = d[3].split_fft();
        let g0 = [d00.clone(), d01.clone(), d01.map(|c| c.conj()), d00];
        let g1 = [d10.clone(), d11.clone(), d11.map(|c| c.conj()), d10];
        LdlTree::Branch(l[2].clone(), Box::new(ffldl(g0)), Box::new(ffldl(g1)))
    } else {
        LdlTree::Branch(
            l[2].clone(),
            Box::new(LdlTree::Leaf(d[0].clone().coefficients.try_into().unwrap())),
            Box::new(LdlTree::Leaf(d[3].clone().coefficients.try_into().unwrap())),
        )
    }
}

/// Normalizes the leaves of an LDL tree using a given normalization value `sigma`.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
    match tree {
        LdlTree::Branch(_ell, left, right) => {
            normalize_tree(left, sigma);
            normalize_tree(right, sigma);
        },
        LdlTree::Leaf(vector) => {
            vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
            vector[1] = Complex64::zero();
        },
    }
}

/// Samples short polynomials using a Falcon tree. Algorithm 11 from the spec [1, p.40].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffsampling<R: Rng>(
    t: &(Polynomial<Complex64>, Polynomial<Complex64>),
    tree: &LdlTree,
    mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
    match tree {
        LdlTree::Branch(ell, left, right) => {
            let bold_t1 = t.1.split_fft();
            let bold_z1 = ffsampling(&bold_t1, right, rng);
            let z1 = Polynomial::<Complex64>::merge_fft(&bold_z1.0, &bold_z1.1);

            // t0' = t0 + (t1 - z1) * l
            let t0_prime = t.0.clone() + (t.1.clone() - z1.clone()).hadamard_mul(ell);

            let bold_t0 = t0_prime.split_fft();
            let bold_z0 = ffsampling(&bold_t0, left, rng);
            let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);

            (z0, z1)
        },
        LdlTree::Leaf(value) => {
            let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
            (
                Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
                Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
            )
        },
    }
}
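A numeric spot check (num-complex only) of the identity L·D·L* = G produced by ldl above, in the scalar case where each matrix entry is a single complex number:

use num_complex::Complex64;

fn main() {
    // Hermitian G = [[a, b], [conj(b), d]] with real a, d.
    let a = Complex64::new(4.0, 0.0);
    let b = Complex64::new(1.0, 2.0);
    let d = Complex64::new(9.0, 0.0);

    // As in `ldl`: l10 = g10 / g00 and d11 = g11 - g00 * |l10|^2.
    let l10 = b.conj() / a;
    let d11 = d - a * (l10 * l10.conj());
    assert!(d11.im.abs() < 1e-12 && d11.re > 0.0); // Schur complement stays real

    // The reassembled (1,1) entry of L·D·L* must equal g11.
    let g11 = l10 * a * l10.conj() + d11;
    assert!((g11 - d).norm() < 1e-12);
}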
1919
src/dsa/rpo_falcon512/math/fft.rs
Normal file
File diff suppressed because it is too large
174
src/dsa/rpo_falcon512/math/field.rs
Normal file
@ -0,0 +1,174 @@
use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use num::{One, Zero};

use super::{fft::CyclotomicFourier, Inverse, MODULUS};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32);

impl FalconFelt {
    pub const fn new(value: i16) -> Self {
        let gtz_bool = value >= 0;
        let gtz_int = gtz_bool as i16;
        let gtz_sign = gtz_int - ((!gtz_bool) as i16);
        let reduced = gtz_sign * (gtz_sign * value) % MODULUS;
        let canonical_representative = (reduced + MODULUS * (1 - gtz_int)) as u32;
        FalconFelt(canonical_representative)
    }

    pub const fn value(&self) -> i16 {
        self.0 as i16
    }

    pub fn balanced_value(&self) -> i16 {
        let value = self.value();
        let g = (value > ((MODULUS) / 2)) as i16;
        value - (MODULUS) * g
    }

    pub const fn multiply(&self, other: Self) -> Self {
        FalconFelt((self.0 * other.0) % MODULUS as u32)
    }
}

impl Add for FalconFelt {
    type Output = Self;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn add(self, rhs: Self) -> Self::Output {
        let (s, _) = self.0.overflowing_add(rhs.0);
        let (d, n) = s.overflowing_sub(MODULUS as u32);
        let (r, _) = d.overflowing_add(MODULUS as u32 * (n as u32));
        FalconFelt(r)
    }
}

impl AddAssign for FalconFelt {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}

impl Sub for FalconFelt {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        self + -rhs
    }
}

impl SubAssign for FalconFelt {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl Neg for FalconFelt {
    type Output = FalconFelt;

    fn neg(self) -> Self::Output {
        let is_nonzero = self.0 != 0;
        let r = MODULUS as u32 - self.0;
        FalconFelt(r * (is_nonzero as u32))
    }
}

impl Mul for FalconFelt {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        FalconFelt((self.0 * rhs.0) % MODULUS as u32)
    }
}

impl MulAssign for FalconFelt {
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}

impl Div for FalconFelt {
    type Output = FalconFelt;

    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self::Output {
        self * rhs.inverse_or_zero()
    }
}

impl DivAssign for FalconFelt {
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs
    }
}

impl Zero for FalconFelt {
    fn zero() -> Self {
        FalconFelt::new(0)
    }

    fn is_zero(&self) -> bool {
        self.0 == 0
    }
}

impl One for FalconFelt {
    fn one() -> Self {
        FalconFelt::new(1)
    }
}

impl Inverse for FalconFelt {
    fn inverse_or_zero(self) -> Self {
        // q - 2 = 0b10_1111_1111_1111, so the addition chain below computes x^(q - 2),
        // i.e., the Fermat inverse modulo q
        let two = self.multiply(self);
        let three = two.multiply(self);
        let six = three.multiply(three);
        let twelve = six.multiply(six);
        let fifteen = twelve.multiply(three);
        let thirty = fifteen.multiply(fifteen);
        let sixty = thirty.multiply(thirty);
        let sixty_three = sixty.multiply(three);

        let sixty_three_sq = sixty_three.multiply(sixty_three);
        let sixty_three_qu = sixty_three_sq.multiply(sixty_three_sq);
        let sixty_three_oc = sixty_three_qu.multiply(sixty_three_qu);
        let sixty_three_hx = sixty_three_oc.multiply(sixty_three_oc);
        let sixty_three_tt = sixty_three_hx.multiply(sixty_three_hx);
        let sixty_three_sf = sixty_three_tt.multiply(sixty_three_tt);

        let all_ones = sixty_three_sf.multiply(sixty_three);
        let two_e_twelve = all_ones.multiply(self);
        let two_e_thirteen = two_e_twelve.multiply(two_e_twelve);

        two_e_thirteen.multiply(all_ones)
    }
}

impl CyclotomicFourier for FalconFelt {
    fn primitive_root_of_unity(n: usize) -> Self {
        let log2n = n.ilog2();
        assert!(log2n <= 12);
        // 1331 is a primitive 2^12-th root of unity modulo the Falcon prime, so squaring
        // it (12 - log2(n)) times yields a primitive n-th root of unity
        let mut a = FalconFelt::new(1331);
        let num_squarings = 12 - n.ilog2();
        for _ in 0..num_squarings {
            a *= a;
        }
        a
    }
}

impl TryFrom<u32> for FalconFelt {
    type Error = String;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        if value >= MODULUS as u32 {
            Err(format!("value {value} is greater than or equal to the field modulus {MODULUS}"))
        } else {
            Ok(FalconFelt::new(value as i16))
        }
    }
}
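inverse_or_zero above is an addition chain for x^(q-2), i.e., a Fermat inverse modulo q = 12289. A std-only square-and-multiply check of the same idea:

// Fermat-inverse sanity check: x^(q-2) * x ≡ 1 (mod q) for q prime.
fn mod_pow(mut base: u64, mut exp: u64, q: u64) -> u64 {
    let mut acc = 1u64;
    base %= q;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % q;
        }
        base = base * base % q;
        exp >>= 1;
    }
    acc
}

fn main() {
    const Q: u64 = 12289; // the Falcon prime (MODULUS above)
    for x in [1u64, 2, 1331, 12288] {
        let inv = mod_pow(x, Q - 2, Q);
        assert_eq!(x * inv % Q, 1);
    }
}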
322
src/dsa/rpo_falcon512/math/mod.rs
Normal file
@ -0,0 +1,322 @@
//! Contains different structs and methods related to the Falcon DSA.
//!
//! It uses and acknowledges the work in:
//!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas
//!    Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use alloc::{string::String, vec::Vec};
use core::ops::MulAssign;

#[cfg(not(feature = "std"))]
use num::Float;
use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64;
use rand::Rng;

use super::MODULUS;

mod fft;
pub use fft::{CyclotomicFourier, FastFft};

mod field;
pub use field::FalconFelt;

mod ffsampling;
pub use ffsampling::{ffldl, ffsampling, gram, normalize_tree, LdlTree};

mod samplerz;
use self::samplerz::sampler_z;

mod polynomial;
pub use polynomial::Polynomial;

pub trait Inverse: Copy + Zero + MulAssign + One {
    /// Gets the inverse of a, or zero if it is zero.
    fn inverse_or_zero(self) -> Self;

    /// Gets the inverses of a batch of elements, skipping over any that are zero.
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        let mut acc = Self::one();
        let mut rp: Vec<Self> = Vec::with_capacity(batch.len());
        for batch_item in batch {
            if !batch_item.is_zero() {
                rp.push(acc);
                acc = *batch_item * acc;
            } else {
                rp.push(Self::zero());
            }
        }
        let mut inv = Self::inverse_or_zero(acc);
        for i in (0..batch.len()).rev() {
            if !batch[i].is_zero() {
                rp[i] *= inv;
                inv *= batch[i];
            }
        }
        rp
    }
}

impl Inverse for Complex64 {
    fn inverse_or_zero(self) -> Self {
        let modulus = self.re * self.re + self.im * self.im;
        Complex64::new(self.re / modulus, -self.im / modulus)
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| Complex64::new(1.0, 0.0) / c).collect()
    }
}

impl Inverse for f64 {
    fn inverse_or_zero(self) -> Self {
        1.0 / self
    }
    fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
        batch.iter().map(|&c| 1.0 / c).collect()
    }
}

/// Samples 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub(crate) fn ntru_gen<R: Rng>(n: usize, rng: &mut R) -> [Polynomial<i16>; 4] {
    loop {
        let f = gen_poly(n, rng);
        let g = gen_poly(n, rng);
        let f_ntt = f.map(|&i| FalconFelt::new(i)).fft();
        if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
            continue;
        }
        let gamma = gram_schmidt_norm_squared(&f, &g);
        if gamma > 1.3689f64 * (MODULUS as f64) {
            continue;
        }

        if let Some((capital_f, capital_g)) =
            ntru_solve(&f.map(|&i| i.into()), &g.map(|&i| i.into()))
        {
            return [
                g,
                -f,
                capital_g.map(|i| i.try_into().unwrap()),
                -capital_f.map(|i| i.try_into().unwrap()),
            ];
        }
    }
}

/// Solves the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X] such that:
///
///     f G - g F = q mod (X^n + 1)
///
/// Algorithm 6 of the specification [1, p.35].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn ntru_solve(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
    let n = f.coefficients.len();
    if n == 1 {
        let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
        if gcd != BigInt::one() {
            return None;
        }
        return Some((
            Polynomial::new(vec![-v * BigInt::from_u32(MODULUS as u32).unwrap()]),
            Polynomial::new(vec![u * BigInt::from_u32(MODULUS as u32).unwrap()]),
        ));
    }

    let f_prime = f.field_norm();
    let g_prime = g.field_norm();

    let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
    let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
    let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();

    let f_minx = f.galois_adjoint();
    let g_minx = g.galois_adjoint();

    let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
    let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);

    match babai_reduce(f, g, &mut capital_f, &mut capital_g) {
        Ok(_) => Some((capital_f, capital_g)),
        Err(_e) => {
            #[cfg(test)]
            {
                panic!("{}", _e);
            }
            #[cfg(not(test))]
            {
                None
            }
        },
    }
}

/// Generates a polynomial of degree at most n-1 whose coefficients are distributed according
/// to a discrete Gaussian with mu = 0 and sigma = 1.17 * sqrt(Q / (2n)).
fn gen_poly<R: Rng>(n: usize, rng: &mut R) -> Polynomial<i16> {
    let mu = 0.0;
    let sigma_star = 1.43300980528773;
    Polynomial {
        coefficients: (0..4096)
            .map(|_| sampler_z(mu, sigma_star, sigma_star - 0.001, rng))
            .collect::<Vec<i16>>()
            .chunks(4096 / n)
            .map(|ch| ch.iter().sum())
            .collect(),
    }
}

/// Computes the Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
    let n = f.coefficients.len();
    let norm_f_squared = f.l2_norm_squared();
    let norm_g_squared = g.l2_norm_squared();
    let gamma1 = norm_f_squared + norm_g_squared;

    let f_fft = f.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let g_fft = g.map(|i| Complex64::new(*i as f64, 0.0)).fft();
    let f_adj_fft = f_fft.map(|c| c.conj());
    let g_adj_fft = g_fft.map(|c| c.conj());
    let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
    let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
    let qf_over_ffgg_fft = f_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let qg_over_ffgg_fft = g_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
    let norm_f_over_ffgg_squared =
        qf_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
    let norm_g_over_ffgg_squared =
        qg_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);

    let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;

    f64::max(gamma1, gamma2)
}

/// Reduces the vector (F, G) relative to (f, g). This method follows the python implementation
/// [1]. Note that this algorithm can end up in an infinite loop. (It's one of the things the
/// author would like to fix.) When this happens, the function returns an error (hence the return
/// type) and the caller generates another keypair with fresh randomness.
///
/// Algorithm 7 in the spec [2, p.35].
///
/// [1]: https://github.com/tprest/falcon.py
///
/// [2]: https://falcon-sign.info/falcon.pdf
fn babai_reduce(
    f: &Polynomial<BigInt>,
    g: &Polynomial<BigInt>,
    capital_f: &mut Polynomial<BigInt>,
    capital_g: &mut Polynomial<BigInt>,
) -> Result<(), String> {
    let bitsize = |bi: &BigInt| (bi.bits() + 7) & (u64::MAX ^ 7);
    let n = f.coefficients.len();
    let size = [
        f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
        53,
    ]
    .into_iter()
    .max()
    .unwrap();
    let shift = (size as i64) - 53;
    let f_adjusted = f
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();
    let g_adjusted = g
        .map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
        .fft();

    let f_star_adjusted = f_adjusted.map(|c| c.conj());
    let g_star_adjusted = g_adjusted.map(|c| c.conj());
    let denominator_fft =
        f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);

    let mut counter = 0;
    loop {
        let capital_size = [
            capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
            53,
        ]
        .into_iter()
        .max()
        .unwrap();

        if capital_size < size {
            break;
        }
        let capital_shift = (capital_size as i64) - 53;
        let capital_f_adjusted = capital_f
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();
        let capital_g_adjusted = capital_g
            .map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
            .fft();

        let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
            + capital_g_adjusted.hadamard_mul(&g_star_adjusted);
        let quotient = numerator.hadamard_div(&denominator_fft).ifft();

        let k = quotient.map(|f| Into::<BigInt>::into(f.re.round() as i64));

        if k.is_zero() {
            break;
        }
        let kf = (k.clone().karatsuba(f))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        let kg = (k.clone().karatsuba(g))
            .reduce_by_cyclotomic(n)
            .map(|bi| bi << (capital_size - size));
        *capital_f -= kf;
        *capital_g -= kg;

        counter += 1;
        if counter > 1000 {
            // If we get here, that means that (with high likelihood) we are in an
            // infinite loop. We know it happens from time to time -- seldom, but it
            // does. It would be nice to fix that! But in order to fix it we need to be
            // able to reproduce it, and for that we need test vectors. So print them
            // and hope that one day they circle back to the implementor.
            return Err(format!("Encountered infinite loop in babai_reduce of falcon-rust.\n\\
Please help the developer(s) fix it! You can do this by sending them the inputs to the function that caused the behavior:\n\\
f: {:?}\n\\
g: {:?}\n\\
capital_f: {:?}\n\\
capital_g: {:?}\n", f.coefficients, g.coefficients, capital_f.coefficients, capital_g.coefficients));
        }
    }
    Ok(())
}

/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
/// Bézout coefficients (u, v) for the relation
///
/// $$ u a + v b = g . $$
///
/// Implementation adapted from Wikipedia [1].
///
/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
    let (mut old_r, mut r) = (a.clone(), b.clone());
    let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
    let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());

    while r != BigInt::zero() {
        let quotient = old_r.clone() / r.clone();
        (old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
        (old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
        (old_t, t) = (t.clone(), old_t.clone() - quotient * t);
    }

    (old_r, old_s, old_t)
}
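A minimal i64 transcription of xgcd above, checking the Bézout identity it is documented to satisfy; std only:

fn xgcd_i64(a: i64, b: i64) -> (i64, i64, i64) {
    let (mut old_r, mut r) = (a, b);
    let (mut old_s, mut s) = (1i64, 0i64);
    let (mut old_t, mut t) = (0i64, 1i64);
    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (s, old_s - q * s);
        (old_t, t) = (t, old_t - q * t);
    }
    (old_r, old_s, old_t) // (gcd, u, v) with u*a + v*b = gcd
}

fn main() {
    let (g, u, v) = xgcd_i64(240, 46);
    assert_eq!(g, 2);
    assert_eq!(u * 240 + v * 46, g);
}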
622
src/dsa/rpo_falcon512/math/polynomial.rs
Normal file
@ -0,0 +1,622 @@
use alloc::vec::Vec;
use core::{
    default::Default,
    fmt::Debug,
    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};

use num::{One, Zero};

use super::{field::FalconFelt, Inverse};
use crate::{
    dsa::rpo_falcon512::{MODULUS, N},
    Felt,
};

#[derive(Debug, Clone, Default)]
pub struct Polynomial<F> {
    pub coefficients: Vec<F>,
}

impl<F> Polynomial<F>
where
    F: Clone,
{
    pub fn new(coefficients: Vec<F>) -> Self {
        Self { coefficients }
    }
}

impl<
        F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone + Inverse,
    > Polynomial<F>
{
    pub fn hadamard_mul(&self, other: &Self) -> Self {
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other.coefficients.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }

    pub fn hadamard_div(&self, other: &Self) -> Self {
        let other_coefficients_inverse = F::batch_inverse_or_zero(&other.coefficients);
        Polynomial::new(
            self.coefficients
                .iter()
                .zip(other_coefficients_inverse.iter())
                .map(|(a, b)| *a * *b)
                .collect(),
        )
    }

    pub fn hadamard_inv(&self) -> Self {
        let coefficients_inverse = F::batch_inverse_or_zero(&self.coefficients);
        Polynomial::new(coefficients_inverse)
    }
}

impl<F: Zero + PartialEq + Clone> Polynomial<F> {
    pub fn degree(&self) -> Option<usize> {
        if self.coefficients.is_empty() {
            return None;
        }
        let mut max_index = self.coefficients.len() - 1;
        while self.coefficients[max_index] == F::zero() {
            if let Some(new_index) = max_index.checked_sub(1) {
                max_index = new_index;
            } else {
                return None;
            }
        }
        Some(max_index)
    }

    pub fn lc(&self) -> F {
        match self.degree() {
            Some(non_negative_degree) => self.coefficients[non_negative_degree].clone(),
            None => F::zero(),
        }
    }
}

/// The following implementations are specific to cyclotomic polynomial rings,
/// i.e., F\[ X \] / <X^n + 1>, and are used extensively in Falcon.
impl<
        F: One
            + Zero
            + Clone
            + Neg<Output = F>
            + MulAssign
            + AddAssign
            + Div<Output = F>
            + Sub<Output = F>
            + PartialEq,
    > Polynomial<F>
{
    /// Reduce the polynomial by X^n + 1.
    pub fn reduce_by_cyclotomic(&self, n: usize) -> Self {
        let mut coefficients = vec![F::zero(); n];
        let mut sign = -F::one();
        for (i, c) in self.coefficients.iter().cloned().enumerate() {
            if i % n == 0 {
                sign *= -F::one();
            }
            coefficients[i % n] += sign.clone() * c;
        }
        Polynomial::new(coefficients)
    }

    /// Computes the field norm of the polynomial as an element of the cyclotomic ring
    /// F\[ X \] / <X^n + 1> relative to one of half the size, i.e., F\[ X \] / <X^(n/2) + 1>.
    ///
    /// Corresponds to formula 3.25 in the spec [1, p.30].
    ///
    /// [1]: https://falcon-sign.info/falcon.pdf
    pub fn field_norm(&self) -> Self {
        let n = self.coefficients.len();
        let mut f0_coefficients = vec![F::zero(); n / 2];
        let mut f1_coefficients = vec![F::zero(); n / 2];
        for i in 0..n / 2 {
            f0_coefficients[i] = self.coefficients[2 * i].clone();
            f1_coefficients[i] = self.coefficients[2 * i + 1].clone();
        }
        let f0 = Polynomial::new(f0_coefficients);
        let f1 = Polynomial::new(f1_coefficients);
        let f0_squared = (f0.clone() * f0).reduce_by_cyclotomic(n / 2);
        let f1_squared = (f1.clone() * f1).reduce_by_cyclotomic(n / 2);
        let x = Polynomial::new(vec![F::zero(), F::one()]);
        f0_squared - (x * f1_squared).reduce_by_cyclotomic(n / 2)
    }

    /// Lifts an element from a cyclotomic polynomial ring to one of double the size.
    pub fn lift_next_cyclotomic(&self) -> Self {
        let n = self.coefficients.len();
        let mut coefficients = vec![F::zero(); n * 2];
        for i in 0..n {
            coefficients[2 * i] = self.coefficients[i].clone();
        }
        Self::new(coefficients)
    }

    /// Computes the Galois adjoint of the polynomial in the cyclotomic ring
    /// F\[ X \] / <X^n + 1>, which corresponds to f(x^2).
    pub fn galois_adjoint(&self) -> Self {
        Self::new(
            self.coefficients
                .iter()
                .enumerate()
                .map(|(i, c)| if i % 2 == 0 { c.clone() } else { c.clone().neg() })
                .collect(),
        )
    }
}

impl<F: Clone + Into<f64>> Polynomial<F> {
    pub(crate) fn l2_norm_squared(&self) -> f64 {
        self.coefficients
            .iter()
            .map(|i| Into::<f64>::into(i.clone()))
            .map(|i| i * i)
            .sum::<f64>()
    }
}

impl<F> PartialEq for Polynomial<F>
where
    F: Zero + PartialEq + Clone + AddAssign,
{
    fn eq(&self, other: &Self) -> bool {
        if self.is_zero() && other.is_zero() {
            true
        } else if self.is_zero() || other.is_zero() {
            false
        } else {
            let self_degree = self.degree().unwrap();
            let other_degree = other.degree().unwrap();
            self.coefficients[0..=self_degree] == other.coefficients[0..=other_degree]
        }
    }
}

impl<F> Eq for Polynomial<F> where F: Zero + PartialEq + Clone + AddAssign {}

impl<F> Add for &Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    type Output = Polynomial<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
            let mut coefficients = self.coefficients.clone();
            for (i, c) in rhs.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            coefficients
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            coefficients
        };
        Self::Output { coefficients }
    }
}

impl<F> Add for Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    type Output = Polynomial<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
            let mut coefficients = self.coefficients.clone();
            for (i, c) in rhs.coefficients.into_iter().enumerate() {
                coefficients[i] += c;
            }
            coefficients
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.into_iter().enumerate() {
                coefficients[i] += c;
            }
            coefficients
        };
        Self::Output { coefficients }
    }
}

impl<F> AddAssign for Polynomial<F>
where
    F: Add<Output = F> + AddAssign + Clone,
{
    fn add_assign(&mut self, rhs: Self) {
        if self.coefficients.len() >= rhs.coefficients.len() {
            for (i, c) in rhs.coefficients.into_iter().enumerate() {
                self.coefficients[i] += c;
            }
        } else {
            let mut coefficients = rhs.coefficients.clone();
            for (i, c) in self.coefficients.iter().enumerate() {
                coefficients[i] += c.clone();
            }
            self.coefficients = coefficients;
        }
    }
}

impl<F> Sub for &Polynomial<F>
where
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + &(-rhs)
    }
}

impl<F> Sub for Polynomial<F>
where
    F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
    type Output = Polynomial<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        self + (-rhs)
    }
}

impl<F> SubAssign for Polynomial<F>
where
    F: Add<Output = F> + Neg<Output = F> + AddAssign + Clone + Sub<Output = F>,
{
    fn sub_assign(&mut self, rhs: Self) {
        self.coefficients = self.clone().sub(rhs).coefficients;
    }
}

impl<F: Neg<Output = F> + Clone> Neg for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F: Neg<Output = F> + Clone> Neg for Polynomial<F> {
    type Output = Self;

    fn neg(self) -> Self::Output {
        Self::Output {
            coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
        }
    }
}

impl<F> Mul for &Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Sub<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Polynomial<F>;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Polynomial::<F>::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Polynomial { coefficients }
    }
}

impl<F> Mul for Polynomial<F>
where
    F: Add + AddAssign + Mul<Output = F> + Zero + PartialEq + Clone,
{
    type Output = Self;

    fn mul(self, other: Self) -> Self::Output {
        if self.is_zero() || other.is_zero() {
            return Self::zero();
        }
        let mut coefficients =
            vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
        for i in 0..self.coefficients.len() {
            for j in 0..other.coefficients.len() {
                coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
            }
        }
        Self { coefficients }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for &Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for Polynomial<F> {
    type Output = Polynomial<F>;

    fn mul(self, other: F) -> Self::Output {
        Polynomial {
            coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
        }
    }
}

impl<F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone>
    Polynomial<F>
{
    /// Multiply two polynomials using Karatsuba's divide-and-conquer algorithm.
    pub fn karatsuba(&self, other: &Self) -> Self {
        Polynomial::new(vector_karatsuba(&self.coefficients, &other.coefficients))
    }
}

impl<F> One for Polynomial<F>
where
    F: Clone + One + PartialEq + Zero + AddAssign,
{
    fn one() -> Self {
        Self { coefficients: vec![F::one()] }
    }
}

impl<F> Zero for Polynomial<F>
where
    F: Zero + PartialEq + Clone + AddAssign,
{
    fn zero() -> Self {
        Self { coefficients: vec![] }
    }

    fn is_zero(&self) -> bool {
        self.degree().is_none()
    }
}

impl<F: Zero + Clone> Polynomial<F> {
    pub fn shift(&self, shamt: usize) -> Self {
        Self {
            coefficients: [vec![F::zero(); shamt], self.coefficients.clone()].concat(),
        }
    }

    pub fn constant(f: F) -> Self {
        Self { coefficients: vec![f] }
    }

    pub fn map<G: Clone, C: FnMut(&F) -> G>(&self, closure: C) -> Polynomial<G> {
        Polynomial::<G>::new(self.coefficients.iter().map(closure).collect())
    }

    pub fn fold<G, C: FnMut(G, &F) -> G + Clone>(&self, mut initial_value: G, closure: C) -> G {
        for c in self.coefficients.iter() {
            initial_value = (closure.clone())(initial_value, c);
        }
        initial_value
    }
}

impl<F> Div<Polynomial<F>> for Polynomial<F>
where
    F: Zero
        + One
        + PartialEq
        + AddAssign
        + Clone
        + Mul<Output = F>
        + MulAssign
        + Div<Output = F>
        + Neg<Output = F>
        + Sub<Output = F>,
{
    type Output = Polynomial<F>;

    fn div(self, denominator: Self) -> Self::Output {
        if denominator.is_zero() {
            panic!();
        }
        if self.is_zero() {
            // return early; without this `return` the degree unwrap below would panic
            return Self::zero();
        }
        let mut remainder = self.clone();
        let mut quotient = Polynomial::<F>::zero();
        while remainder.degree().unwrap() >= denominator.degree().unwrap() {
            let shift = remainder.degree().unwrap() - denominator.degree().unwrap();
            let quotient_coefficient = remainder.lc() / denominator.lc();
            let monomial = Self::constant(quotient_coefficient).shift(shift);
            quotient += monomial.clone();
            remainder -= monomial * denominator.clone();
            if remainder.is_zero() {
                break;
            }
        }
        quotient
    }
}

fn vector_karatsuba<
    F: Zero + AddAssign + Mul<Output = F> + Sub<Output = F> + Div<Output = F> + Clone,
>(
    left: &[F],
    right: &[F],
) -> Vec<F> {
    let n = left.len();
    if n <= 8 {
        let mut product = vec![F::zero(); left.len() + right.len() - 1];
        for (i, l) in left.iter().enumerate() {
            for (j, r) in right.iter().enumerate() {
                product[i + j] += l.clone() * r.clone();
            }
        }
        return product;
    }
    let n_over_2 = n / 2;
    let mut product = vec![F::zero(); 2 * n - 1];
    let left_lo = &left[0..n_over_2];
    let right_lo = &right[0..n_over_2];
    let left_hi = &left[n_over_2..];
    let right_hi = &right[n_over_2..];
    let left_sum: Vec<F> =
        left_lo.iter().zip(left_hi).map(|(a, b)| a.clone() + b.clone()).collect();
    let right_sum: Vec<F> =
        right_lo.iter().zip(right_hi).map(|(a, b)| a.clone() + b.clone()).collect();

    let prod_lo = vector_karatsuba(left_lo, right_lo);
    let prod_hi = vector_karatsuba(left_hi, right_hi);
    let prod_mid: Vec<F> = vector_karatsuba(&left_sum, &right_sum)
        .iter()
        .zip(prod_lo.iter().zip(prod_hi.iter()))
        .map(|(s, (l, h))| s.clone() - (l.clone() + h.clone()))
        .collect();

    for (i, l) in prod_lo.into_iter().enumerate() {
        product[i] = l;
    }
    for (i, m) in prod_mid.into_iter().enumerate() {
        product[i + n_over_2] += m;
    }
    for (i, h) in prod_hi.into_iter().enumerate() {
        product[i + n] += h
    }
    product
}

impl From<Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<FalconFelt>> for Polynomial<Felt> {
    fn from(item: &Polynomial<FalconFelt>) -> Self {
        let res: Vec<Felt> =
            item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
        Polynomial::new(res)
    }
}

impl From<Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Polynomial<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Polynomial<i16>) -> Self {
        let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl From<&Vec<i16>> for Polynomial<FalconFelt> {
    fn from(item: &Vec<i16>) -> Self {
        let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
        Polynomial::new(res)
    }
}

impl Polynomial<FalconFelt> {
    pub fn norm_squared(&self) -> u64 {
        self.coefficients
            .iter()
            .map(|&i| i.balanced_value() as i64)
            .map(|i| (i * i) as u64)
            .sum::<u64>()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the coefficients of this polynomial as field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.coefficients.iter().map(|&a| Felt::from(a.value() as u16)).collect()
    }

    // POLYNOMIAL OPERATIONS
    // --------------------------------------------------------------------------------------------

    /// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
    /// of the input polynomials are less than 512 and their coefficients are less than the modulus
    /// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
    /// than the Miden prime.
    ///
    /// Note that this multiplication is not over Z_p\[x\]/(phi).
    pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
        let mut c = [0; 2 * N];
        for i in 0..N {
            for j in 0..N {
                c[i + j] += a.coefficients[i].value() as u64 * b.coefficients[j].value() as u64;
            }
        }

        c
    }

    /// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
    /// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
    pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
        let mut c = [FalconFelt::zero(); N];
        let modulus = MODULUS as u16;
        for i in 0..N {
            let ai = a[N + i] % modulus as u64;
            let neg_ai = (modulus - ai as u16) % modulus;

            let bi = (a[i] % modulus as u64) as u16;
            c[i] = FalconFelt::new(((neg_ai + bi) % modulus) as i16);
        }

        Self::new(c.to_vec())
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{FalconFelt, Polynomial, N};

    #[test]
    fn test_negacyclic_reduction() {
        let coef1: [u8; N] = rand_utils::rand_array();
        let coef2: [u8; N] = rand_utils::rand_array();

        let poly1 = Polynomial::new(coef1.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let poly2 = Polynomial::new(coef2.iter().map(|&a| FalconFelt::new(a as i16)).collect());
        let prod = poly1.clone() * poly2.clone();

        assert_eq!(
            prod.reduce_by_cyclotomic(N),
            Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
        );
    }
}
299
src/dsa/rpo_falcon512/math/samplerz.rs
Normal file
@@ -0,0 +1,299 @@
use core::f64::consts::LN_2;

#[cfg(not(feature = "std"))]
use num::Float;
use rand::Rng;

/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
/// equal to sigma_max.
fn base_sampler(bytes: [u8; 9]) -> i16 {
    const RCDT: [u128; 18] = [
        3024686241123004913666,
        1564742784480091954050,
        636254429462080897535,
        199560484645026482916,
        47667343854657281903,
        8595902006365044063,
        1163297957344668388,
        117656387352093658,
        8867391802663976,
        496969357462633,
        20680885154299,
        638331848991,
        14602316184,
        247426747,
        3104126,
        28824,
        198,
        1,
    ];
    let u = u128::from_be_bytes([vec![0u8; 7], bytes.to_vec()].concat().try_into().unwrap());
    RCDT.into_iter().filter(|r| u < *r).count() as i16
}

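// Illustrative sketch (not part of the original source): `base_sampler` is inverse-CDF
// sampling against a reverse cumulative distribution table, so the returned value is the
// number of table entries strictly greater than the random input. The same logic with a
// toy 3-entry table:
// ```
// fn toy_base_sampler(u: u32) -> i16 {
//     const TOY_RCDT: [u32; 3] = [300, 40, 2];
//     TOY_RCDT.into_iter().filter(|r| u < *r).count() as i16
// }
//
// assert_eq!(toy_base_sampler(500), 0); // above all thresholds
// assert_eq!(toy_base_sampler(100), 1); // below 300 only
// assert_eq!(toy_base_sampler(1), 3);   // below all thresholds
// ```
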
/// Computes an integer approximation of 2^63 * ccs * exp(-x).
fn approx_exp(x: f64, ccs: f64) -> u64 {
    // The constants C are used to approximate exp(-x); these
    // constants are taken from FACCT (up to a scaling factor
    // of 2^63):
    // https://eprint.iacr.org/2018/1234
    // https://github.com/raykzhao/gaussian
    const C: [u64; 13] = [
        0x00000004741183a3u64,
        0x00000036548cfc06u64,
        0x0000024fdcbf140au64,
        0x0000171d939de045u64,
        0x0000d00cf58f6f84u64,
        0x000680681cf796e3u64,
        0x002d82d8305b0feau64,
        0x011111110e066fd0u64,
        0x0555555555070f00u64,
        0x155555555581ff00u64,
        0x400000000002b400u64,
        0x7fffffffffff4800u64,
        0x8000000000000000u64,
    ];

    let mut z: u64;
    let mut y: u64;
    let twoe63 = 1u64 << 63;

    y = C[0];
    z = f64::floor(x * (twoe63 as f64)) as u64;
    for cu in C.iter().skip(1) {
        let zy = (z as u128) * (y as u128);
        y = cu - ((zy >> 63) as u64);
    }

    z = f64::floor((twoe63 as f64) * ccs) as u64;

    (((z as u128) * (y as u128)) >> 63) as u64
}

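// Illustrative note (an assumption about intent, not part of the original source): the
// loop above is a fixed-point Horner evaluation. With z ≈ 2^63 * x and y_0 = C[0], each
// step computes
//
//     y_{k+1} = C[k + 1] - ((z * y_k) >> 63)  ≈  2^63 * (c_{k+1} - x * (y_k / 2^63))
//
// where the constants c_k approach 1, 1, 1/2, 1/6, 1/24, ... from the top of the table,
// so the final y ≈ 2^63 * (1 - x + x^2/2 - x^3/6 + ...) = 2^63 * exp(-x) for
// 0 <= x < ln(2). The last multiplication by floor(2^63 * ccs) rescales the result to
// 2^63 * ccs * exp(-x).
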
/// A random bool that is true with probability ≈ ccs · exp(-x).
fn ber_exp(x: f64, ccs: f64, random_bytes: [u8; 7]) -> bool {
    // 0.69314718055994530941 = ln(2)
    let s = f64::floor(x / LN_2) as usize;
    let r = x - LN_2 * (s as f64);
    let shamt = usize::min(s, 63);
    let z = ((((approx_exp(r, ccs) as u128) << 1) - 1) >> shamt) as u64;
    let mut w = 0i16;
    for (index, i) in (0..64).step_by(8).rev().enumerate() {
        let byte = random_bytes[index];
        w = (byte as i16) - (((z >> i) & 0xff) as i16);
        if w != 0 {
            break;
        }
    }
    w < 0
}

/// Samples an integer from the Gaussian distribution with given mean (mu) and standard deviation
/// (sigma).
pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R) -> i16 {
    const SIGMA_MAX: f64 = 1.8205;
    const INV_2SIGMA_MAX_SQ: f64 = 1f64 / (2f64 * SIGMA_MAX * SIGMA_MAX);
    let isigma = 1f64 / sigma;
    let dss = 0.5f64 * isigma * isigma;
    let s = f64::floor(mu);
    let r = mu - s;
    let ccs = sigma_min * isigma;
    loop {
        let z0 = base_sampler(rng.gen());
        let random_byte: u8 = rng.gen();
        let b = (random_byte & 1) as i16;
        let z = b + ((b << 1) - 1) * z0;
        let zf_min_r = (z as f64) - r;
        // x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
        let x = zf_min_r * zf_min_r * dss - (z0 * z0) as f64 * INV_2SIGMA_MAX_SQ;
        if ber_exp(x, ccs, rng.gen()) {
            return z + (s as i16);
        }
    }
}

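// Usage sketch (illustrative only; the choice of `ChaCha20Rng` and the loose bound in
// the assertion are assumptions, not part of the original source):
// ```
// use rand::SeedableRng;
// use rand_chacha::ChaCha20Rng;
//
// let mut rng = ChaCha20Rng::from_seed([7_u8; 32]);
// // sample around mu = -1.5 with sigma close to SIGMA_MAX
// let z = sampler_z(-1.5, 1.7, 1.277833697, &mut rng);
// assert!((z as f64 + 1.5).abs() < 20.0 * 1.7);
// ```
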
#[cfg(all(test, feature = "std"))]
mod test {
    use alloc::vec::Vec;
    use std::{thread::sleep, time::Duration};

    use rand::RngCore;

    use super::{approx_exp, ber_exp, sampler_z};

    /// RNG used only for testing purposes, whereby the produced
    /// string of random bytes is equal to the one it is initialized
    /// with. Whatever you do, do not use this RNG in production.
    struct UnsafeBufferRng {
        buffer: Vec<u8>,
        index: usize,
    }

    impl UnsafeBufferRng {
        fn new(buffer: &[u8]) -> Self {
            Self { buffer: buffer.to_vec(), index: 0 }
        }

        fn next(&mut self) -> u8 {
            if self.buffer.len() <= self.index {
                // panic!("Ran out of buffer.");
                sleep(Duration::from_millis(10));
                0u8
            } else {
                let return_value = self.buffer[self.index];
                self.index += 1;
                return_value
            }
        }
    }

    impl RngCore for UnsafeBufferRng {
        fn next_u32(&mut self) -> u32 {
            // let bytes: [u8; 4] = (0..4)
            //     .map(|_| self.next())
            //     .collect_vec()
            //     .try_into()
            //     .unwrap();
            // u32::from_be_bytes(bytes)
            u32::from_le_bytes([self.next(), 0, 0, 0])
        }

        fn next_u64(&mut self) -> u64 {
            // let bytes: [u8; 8] = (0..8)
            //     .map(|_| self.next())
            //     .collect_vec()
            //     .try_into()
            //     .unwrap();
            // u64::from_be_bytes(bytes)
            u64::from_le_bytes([self.next(), 0, 0, 0, 0, 0, 0, 0])
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            for d in dest.iter_mut() {
                *d = self.next();
            }
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            for d in dest.iter_mut() {
                *d = self.next();
            }
            Ok(())
        }
    }

    #[test]
    fn test_unsafe_buffer_rng() {
        let seed_bytes = hex::decode("7FFECD162AE2").unwrap();
        let mut rng = UnsafeBufferRng::new(&seed_bytes);
        let generated_bytes: Vec<u8> = (0..seed_bytes.len()).map(|_| rng.next()).collect();
        assert_eq!(seed_bytes, generated_bytes);
    }

    #[test]
    fn test_approx_exp() {
        let precision = 1u64 << 14;
        // known answers were generated with the following sage script:
        // ```sage
        // num_samples = 10
        // precision = 200
        // R = Reals(precision)
        //
        // print(f"let kats : [(f64, f64, u64);{num_samples}] = [")
        // for i in range(num_samples):
        //     x = RDF.random_element(0.0, 0.693147180559945)
        //     ccs = RDF.random_element(0.0, 1.0)
        //     res = round(2^63 * R(ccs) * exp(R(-x)))
        //     print(f"({x}, {ccs}, {res}),")
        // print("];")
        // ```
        let kats: [(f64, f64, u64); 10] = [
            (0.2314993926072656, 0.8148006314615972, 5962140072160879737),
            (0.2648875572812225, 0.12769669655309035, 903712282351034505),
            (0.11251957513682391, 0.9264611470305881, 7635725498677341553),
            (0.04353439307256617, 0.5306497137523327, 4685877322232397936),
            (0.41834495299784347, 0.879438856118578, 5338392138535350986),
            (0.32579398973228557, 0.16513412873289002, 1099603299296456803),
            (0.5939508073919817, 0.029776019144967303, 151637565622779016),
            (0.2932367999399056, 0.37123847662857923, 2553827649386670452),
            (0.5005699297417507, 0.31447208863888976, 1758235618083658825),
            (0.4876437338498085, 0.6159515298936868, 3488632981903743976),
        ];
        for (x, ccs, answer) in kats {
            let difference = (answer as i128) - (approx_exp(x, ccs) as i128);
            assert!(
                (difference * difference) as u64 <= precision * precision,
                "answer: {answer} versus approximation: {}\ndifference: {} whereas precision: {}",
                approx_exp(x, ccs),
                difference,
                precision
            );
        }
    }

    #[test]
    fn test_ber_exp() {
        let kats = [
            (1.268_314_048_020_498_4, 0.749_990_853_267_664_9, hex::decode("ea000000000000").unwrap(), false),
            (0.001_563_917_959_143_409_6, 0.749_990_853_267_664_9, hex::decode("6c000000000000").unwrap(), true),
            (0.017_921_215_753_999_235, 0.749_990_853_267_664_9, hex::decode("c2000000000000").unwrap(), false),
            (0.776_117_648_844_980_6, 0.751_181_554_542_520_8, hex::decode("58000000000000").unwrap(), true),
        ];
        for (x, ccs, bytes, answer) in kats {
            assert_eq!(answer, ber_exp(x, ccs, bytes.try_into().unwrap()));
        }
    }

    #[test]
    fn test_sampler_z() {
        let sigma_min = 1.277833697;
        // known answers from the doc, table 3.2, page 44
        // https://falcon-sign.info/falcon.pdf
        // The zeros were added to account for dropped bytes.
        let kats = [
            (-91.90471153063714, 1.7037990414754918, hex::decode("0fc5442ff043d66e91d1ea000000000000cac64ea5450a22941edc6c").unwrap(), -92),
            (-8.322564895434937, 1.7037990414754918, hex::decode("f4da0f8d8444d1a77265c2000000000000ef6f98bbbb4bee7db8d9b3").unwrap(), -8),
            (-19.096516109216804, 1.7035823083824078, hex::decode("db47f6d7fb9b19f25c36d6000000000000b9334d477a8bc0be68145d").unwrap(), -20),
            (-11.335543982423326, 1.7035823083824078, hex::decode("ae41b4f5209665c74d00dc000000000000c1a8168a7bb516b3190cb42c1ded26cd52000000000000aed770eca7dd334e0547bcc3c163ce0b").unwrap(), -12),
            (7.9386734193997555, 1.6984647769450156, hex::decode("31054166c1012780c603ae0000000000009b833cec73f2f41ca5807c000000000000c89c92158834632f9b1555").unwrap(), 8),
            (-28.990850086867255, 1.6984647769450156, hex::decode("737e9d68a50a06dbbc6477").unwrap(), -30),
            (-9.071257914091655, 1.6980782114808988, hex::decode("a98ddd14bf0bf22061d632").unwrap(), -10),
            (-43.88754568839566, 1.6980782114808988, hex::decode("3cbf6818a68f7ab9991514").unwrap(), -41),
            (-58.17435547946095, 1.7010983419195522, hex::decode("6f8633f5bfa5d26848668e0000000000003d5ddd46958e97630410587c").unwrap(), -61),
            (-43.58664906684732, 1.7010983419195522, hex::decode("272bc6c25f5c5ee53f83c40000000000003a361fbc7cc91dc783e20a").unwrap(), -46),
            (-34.70565203313315, 1.7009387219711465, hex::decode("45443c59574c2c3b07e2e1000000000000d9071e6d133dbe32754b0a").unwrap(), -34),
            (-44.36009577368896, 1.7009387219711465, hex::decode("6ac116ed60c258e2cbaeab000000000000728c4823e6da36e18d08da0000000000005d0cc104e21cc7fd1f5ca8000000000000d9dbb675266c928448059e").unwrap(), -44),
            (-21.783037079346236, 1.6958406126012802, hex::decode("68163bc1e2cbf3e18e7426").unwrap(), -23),
            (-39.68827784633828, 1.6958406126012802, hex::decode("d6a1b51d76222a705a0259").unwrap(), -40),
            (-18.488607061056847, 1.6955259305261838, hex::decode("f0523bfaa8a394bf4ea5c10000000000000f842366fde286d6a30803").unwrap(), -22),
            (-48.39610939101591, 1.6955259305261838, hex::decode("87bd87e63374cee62127fc0000000000006931104aab64f136a0485b").unwrap(), -50),
        ];
        for (mu, sigma, random_bytes, answer) in kats {
            assert_eq!(
                sampler_z(mu, sigma, sigma_min, &mut UnsafeBufferRng::new(&random_bytes)),
                answer
            );
        }
    }
}
105
src/dsa/rpo_falcon512/mod.rs
Normal file
@@ -0,0 +1,105 @@
use crate::{
    hash::rpo::Rpo256,
    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
    Felt, Word, ZERO,
};

mod hash_to_point;
mod keys;
mod math;
mod signature;

pub use self::{
    keys::{PubKeyPoly, PublicKey, SecretKey},
    math::Polynomial,
    signature::{Signature, SignatureHeader, SignaturePoly},
};

// CONSTANTS
// ================================================================================================

// The Falcon modulus p.
const MODULUS: i16 = 12289;

// Number of bits needed to encode an element in the Falcon field.
const FALCON_ENCODING_BITS: u32 = 14;

// The Falcon-512 parameter N. This is the degree of the polynomial `phi := x^N + 1` defining
// the ring Z_p[x]/(phi).
const N: usize = 512;
const LOG_N: u8 = 9;

/// Length of the nonce used during signature generation.
const SIG_NONCE_LEN: usize = 40;

/// Number of field elements used to encode a nonce.
const NONCE_ELEMENTS: usize = 8;

/// Public key length as a u8 vector.
pub const PK_LEN: usize = 897;

/// Secret key length as a u8 vector.
pub const SK_LEN: usize = 1281;

/// Signature length as a u8 vector.
const SIG_POLY_BYTE_LEN: usize = 625;

/// Bound on the squared-norm of the signature.
const SIG_L2_BOUND: u64 = 34034726;

/// Standard deviation of the Gaussian over the lattice.
const SIGMA: f64 = 165.7366171829776;

// TYPE ALIASES
// ================================================================================================

type ShortLatticeBasis = [Polynomial<i16>; 4];

// NONCE
// ================================================================================================

/// Nonce of the Falcon signature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Nonce([u8; SIG_NONCE_LEN]);

impl Nonce {
    /// Returns a new [Nonce] instantiated from the provided bytes.
    pub fn new(bytes: [u8; SIG_NONCE_LEN]) -> Self {
        Self(bytes)
    }

    /// Returns the underlying bytes of this nonce.
    pub fn as_bytes(&self) -> &[u8; SIG_NONCE_LEN] {
        &self.0
    }

    /// Converts byte representation of the nonce into field element representation.
    ///
    /// Nonce bytes are converted to field elements by taking consecutive 5 byte chunks
    /// of the nonce and interpreting them as field elements.
    pub fn to_elements(&self) -> [Felt; NONCE_ELEMENTS] {
        let mut buffer = [0_u8; 8];
        let mut result = [ZERO; 8];
        for (i, bytes) in self.0.chunks(5).enumerate() {
            buffer[..5].copy_from_slice(bytes);
            // we can safely (without overflow) create a new Felt from u64 value here since this
            // value contains at most 5 bytes
            result[i] = Felt::new(u64::from_le_bytes(buffer));
        }

        result
    }
}

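// Worked example (illustrative, not part of the original source): the 40 nonce bytes
// split into eight 5-byte chunks, each read as a little-endian u64 below 2^40 and
// therefore a valid field element:
// ```
// let nonce = Nonce::new([1_u8; 40]);
// let elements = nonce.to_elements();
// assert_eq!(elements.len(), 8);
// // each chunk [1, 1, 1, 1, 1] encodes 0x0101010101
// assert_eq!(elements[0], Felt::new(0x0101010101));
// ```
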
impl Serializable for &Nonce {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.0)
    }
}

impl Deserializable for Nonce {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let bytes = source.read()?;
        Ok(Self(bytes))
    }
}
375
src/dsa/rpo_falcon512/signature.rs
Normal file
@@ -0,0 +1,375 @@
use alloc::{string::ToString, vec::Vec};
use core::ops::Deref;

use num::Zero;

use super::{
    hash_to_point::hash_to_point_rpo256,
    keys::PubKeyPoly,
    math::{FalconFelt, FastFft, Polynomial},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
    Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
};

// FALCON SIGNATURE
// ================================================================================================

/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, a nonce `r`, and a
/// public key polynomial `h` where:
/// - p := 12289
/// - phi := x^512 + 1
///
/// The signature verifies against a public key `pk` if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
/// where |.| is the norm and:
/// - c = HashToPoint(r || message)
/// - pk = Rpo256::hash(h)
///
/// Here h is a polynomial representing the public key and pk is its digest using the Rpo256 hash
/// function. c is a polynomial that is the hash-to-point of the message being signed.
///
/// The polynomial h is serialized as:
/// 1. 1 byte representing log2(512), i.e., 9.
/// 2. 896 bytes for the public key itself.
///
/// The signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2`
///    polynomial together with the degree of the irreducible polynomial phi. For RPO Falcon512,
///    the header byte is set to `10111001` which differentiates it from the standardized
///    instantiation of the Falcon signature.
/// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
    header: SignatureHeader,
    nonce: Nonce,
    s2: SignaturePoly,
    h: PubKeyPoly,
}

impl Signature {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    pub fn new(nonce: Nonce, h: PubKeyPoly, s2: SignaturePoly) -> Signature {
        Self {
            header: SignatureHeader::default(),
            nonce,
            s2,
            h,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the public key polynomial h.
    pub fn pk_poly(&self) -> &PubKeyPoly {
        &self.h
    }

    /// Returns the polynomial representation of the signature in Z_p\[x\]/(phi).
    pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
        &self.s2
    }

    /// Returns the nonce component of the signature.
    pub fn nonce(&self) -> &Nonce {
        &self.nonce
    }

    // SIGNATURE VERIFICATION
    // --------------------------------------------------------------------------------------------

    /// Returns true if this signature is a valid signature for the specified message generated
    /// against the secret key matching the specified public key commitment.
    pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
        // compute the hash of the public key polynomial
        let h_felt: Polynomial<Felt> = (&**self.pk_poly()).into();
        let h_digest: Word = Rpo256::hash_elements(&h_felt.coefficients).into();
        if h_digest != pubkey_com {
            return false;
        }

        let c = hash_to_point_rpo256(message, &self.nonce);
        verify_helper(&c, &self.s2, self.pk_poly())
    }
}

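// Usage sketch (illustrative, not part of the original source; the `public_key()`
// accessor and its conversion to a `Word` commitment are assumptions about the `keys`
// module):
// ```
// use rand::SeedableRng;
// use rand_chacha::ChaCha20Rng;
//
// let mut rng = ChaCha20Rng::from_seed([0_u8; 32]);
// let sk = SecretKey::with_rng(&mut rng);
// let message = Word::default();
// let signature = sk.sign_with_rng(message, &mut rng);
// let pubkey_com: Word = sk.public_key().into(); // assumed commitment conversion
// assert!(signature.verify(message, pubkey_com));
// ```
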
impl Serializable for Signature {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write(&self.header);
        target.write(&self.nonce);
        target.write(&self.s2);
        target.write(&self.h);
    }
}

impl Deserializable for Signature {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let header = source.read()?;
        let nonce = source.read()?;
        let s2 = source.read()?;
        let h = source.read()?;

        Ok(Self { header, nonce, s2, h })
    }
}

// SIGNATURE HEADER
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignatureHeader(u8);

impl Default for SignatureHeader {
    /// According to section 3.11.3 in the specification [1], the signature header has the format
    /// `0cc1nnnn` where:
    ///
    /// 1. `cc` signifies the encoding method. `01` denotes using the compression encoding method
    ///    and `10` denotes encoding using the uncompressed method.
    /// 2. `nnnn` encodes `LOG_N`.
    ///
    /// For RPO Falcon 512 we use compression encoding and N = 512. Moreover, to differentiate the
    /// RPO Falcon variant from the reference variant using SHAKE256, we flip the first bit in the
    /// header. Thus, for RPO Falcon 512 the header is `10111001`.
    ///
    /// [1]: https://falcon-sign.info/falcon.pdf
    fn default() -> Self {
        Self(0b1011_1001)
    }
}

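// Illustrative decomposition (not part of the original source) of the default header
// byte, matching the checks performed during deserialization below:
// ```
// let header: u8 = 0b1011_1001;
// let (encoding, log_n) = (header >> 4, header & 0b0000_1111);
// assert_eq!(encoding, 0b1011); // flipped MSB + `01` compression + the fixed 1 bit
// assert_eq!(log_n, 9);         // LOG_N, i.e. N = 2^9 = 512
// ```
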
impl Serializable for &SignatureHeader {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_u8(self.0)
    }
}

impl Deserializable for SignatureHeader {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let header = source.read_u8()?;
        let (encoding, log_n) = (header >> 4, header & 0b00001111);
        if encoding != 0b1011 {
            return Err(DeserializationError::InvalidValue(
                "Failed to decode signature: not supported encoding algorithm".to_string(),
            ));
        }

        if log_n != LOG_N {
            return Err(DeserializationError::InvalidValue(
                format!("Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided")
            ));
        }

        Ok(Self(header))
    }
}

// SIGNATURE POLYNOMIAL
// ================================================================================================

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignaturePoly(pub Polynomial<FalconFelt>);

impl Deref for SignaturePoly {
    type Target = Polynomial<FalconFelt>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Polynomial<FalconFelt>> for SignaturePoly {
    fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
        Self(pk_poly)
    }
}

impl TryFrom<&[i16; N]> for SignaturePoly {
    type Error = ();

    fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
        if are_coefficients_valid(coefficients) {
            Ok(Self(coefficients.to_vec().into()))
        } else {
            Err(())
        }
    }
}

impl Serializable for &SignaturePoly {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let sig_coeff: Vec<i16> = self.0.coefficients.iter().map(|a| a.balanced_value()).collect();
        let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];

        let mut acc = 0;
        let mut acc_len = 0;
        let mut v = 0;
        let mut t;
        let mut w;

        // For each coefficient of x:
        // - the sign is encoded on 1 bit
        // - the 7 lower bits are encoded naively (binary)
        // - the high bits are encoded in unary encoding
        //
        // Algorithm 17 p. 47 of the specification [1].
        //
        // [1]: https://falcon-sign.info/falcon.pdf
        for &c in sig_coeff.iter() {
            acc <<= 1;
            t = c;

            if t < 0 {
                t = -t;
                acc |= 1;
            }
            w = t as u16;

            acc <<= 7;
            let mask = 127_u32;
            acc |= (w as u32) & mask;
            w >>= 7;

            acc_len += 8;

            acc <<= w + 1;
            acc |= 1;
            acc_len += w + 1;

            while acc_len >= 8 {
                acc_len -= 8;

                sk_bytes[v] = (acc >> acc_len) as u8;
                v += 1;
            }
        }

        if acc_len > 0 {
            sk_bytes[v] = (acc << (8 - acc_len)) as u8;
        }
        target.write_bytes(&sk_bytes);
    }
}

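// Worked example (illustrative, not part of the original source) of the compression
// encoding for a single coefficient, following Algorithm 17: for c = -300 the encoder
// emits
//
//     sign bit   : 1                   (c < 0, so t = 300)
//     low 7 bits : 300 & 127 = 44  ->  0101100
//     high bits  : 300 >> 7  = 2   ->  unary "001" (two zeros, then the stop bit 1)
//
// i.e. the bit string `1 0101100 001`, 11 bits in total.
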
impl Deserializable for SignaturePoly {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;

        let mut input_idx = 0;
        let mut acc = 0u32;
        let mut acc_len = 0;
        let mut coefficients = [FalconFelt::zero(); N];

        // Algorithm 18 p. 48 of the specification [1].
        //
        // [1]: https://falcon-sign.info/falcon.pdf
        for c in coefficients.iter_mut() {
            acc = (acc << 8) | (input[input_idx] as u32);
            input_idx += 1;
            let b = acc >> acc_len;
            let s = b & 128;
            let mut m = b & 127;

            loop {
                if acc_len == 0 {
                    acc = (acc << 8) | (input[input_idx] as u32);
                    input_idx += 1;
                    acc_len = 8;
                }
                acc_len -= 1;
                if ((acc >> acc_len) & 1) != 0 {
                    break;
                }
                m += 128;
                if m >= 2048 {
                    return Err(DeserializationError::InvalidValue(format!(
                        "Failed to decode signature: high bits {m} exceed 2048",
                    )));
                }
            }
            if s != 0 && m == 0 {
                return Err(DeserializationError::InvalidValue(
                    "Failed to decode signature: -0 is forbidden".to_string(),
                ));
            }

            let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
            *c = FalconFelt::new(felt as i16);
        }

        if (acc & ((1 << acc_len) - 1)) != 0 {
            return Err(DeserializationError::InvalidValue(
                "Failed to decode signature: non-zero unused bits in the last byte".to_string(),
            ));
        }
        Ok(Polynomial::new(coefficients.to_vec()).into())
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Takes the hash-to-point polynomial `c` of a message, the signature polynomial over
/// the message `s2`, and a public key polynomial `h`, and returns `true` if the signature
/// is a valid signature for the given parameters; otherwise it returns `false`.
fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PubKeyPoly) -> bool {
    let h_fft = h.fft();
    let s2_fft = s2.fft();
    let c_fft = c.fft();

    // compute the signature polynomial s1 using s1 = c - s2 * h
    let s1_fft = c_fft - s2_fft.hadamard_mul(&h_fft);
    let s1 = s1_fft.ifft();

    // compute the norm squared of (s1, s2)
    let length_squared_s1 = s1.norm_squared();
    let length_squared_s2 = s2.norm_squared();
    let length_squared = length_squared_s1 + length_squared_s2;

    length_squared < SIG_L2_BOUND
}

/// Checks whether a set of coefficients is a valid one for a signature polynomial.
fn are_coefficients_valid(x: &[i16]) -> bool {
    if x.len() != N {
        return false;
    }

    for &c in x {
        if !(-2047..=2047).contains(&c) {
            return false;
        }
    }

    true
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    use super::{super::SecretKey, *};

    #[test]
    fn test_serialization_round_trip() {
        let seed = [0_u8; 32];
        let mut rng = ChaCha20Rng::from_seed(seed);

        let sk = SecretKey::with_rng(&mut rng);
        let signature = sk.sign_with_rng(Word::default(), &mut rng);
        let serialized = signature.to_bytes();
        let deserialized = Signature::read_from_bytes(&serialized).unwrap();
        assert_eq!(signature.sig_poly(), deserialized.sig_poly());
    }
}
383
src/hash/blake/mod.rs
Normal file
@@ -0,0 +1,383 @@
use alloc::{string::String, vec::Vec};
use core::{
    mem::{size_of, transmute, transmute_copy},
    ops::Deref,
    slice::{self, from_raw_parts},
};

use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
use crate::utils::{
    bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
    DeserializationError, HexParseError, Serializable,
};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

const DIGEST32_BYTES: usize = 32;
const DIGEST24_BYTES: usize = 24;
const DIGEST20_BYTES: usize = 20;

// BLAKE3 N-BIT OUTPUT
// ================================================================================================

/// N-bytes output of a blake3 function.
///
/// Note: `N` can't be greater than `32` because [`Digest::as_bytes`] currently supports only 32
/// bytes.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]);

impl<const N: usize> Blake3Digest<N> {
    pub fn digests_as_bytes(digests: &[Blake3Digest<N>]) -> &[u8] {
        let p = digests.as_ptr();
        let len = digests.len() * N;
        unsafe { slice::from_raw_parts(p as *const u8, len) }
    }
}

impl<const N: usize> Default for Blake3Digest<N> {
    fn default() -> Self {
        Self([0; N])
    }
}

impl<const N: usize> Deref for Blake3Digest<N> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<const N: usize> From<Blake3Digest<N>> for [u8; N] {
    fn from(value: Blake3Digest<N>) -> Self {
        value.0
    }
}

impl<const N: usize> From<[u8; N]> for Blake3Digest<N> {
    fn from(value: [u8; N]) -> Self {
        Self(value)
    }
}

impl<const N: usize> From<Blake3Digest<N>> for String {
    fn from(value: Blake3Digest<N>) -> Self {
        bytes_to_hex_string(value.as_bytes())
    }
}

impl<const N: usize> TryFrom<&str> for Blake3Digest<N> {
    type Error = HexParseError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes(value).map(|v| v.into())
    }
}

impl<const N: usize> Serializable for Blake3Digest<N> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.0);
    }
}

impl<const N: usize> Deserializable for Blake3Digest<N> {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        source.read_array().map(Self)
    }
}

impl<const N: usize> Digest for Blake3Digest<N> {
    fn as_bytes(&self) -> [u8; 32] {
        // compile-time assertion
        assert!(N <= 32, "digest currently supports only 32 bytes!");
        expand_bytes(&self.0)
    }
}

// BLAKE3 256-BIT OUTPUT
// ================================================================================================

/// 256-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_256;

impl Hasher for Blake3_256 {
    /// Blake3 collision resistance is 128-bits for 32-bytes output.
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = Blake3Digest<32>;

    fn hash(bytes: &[u8]) -> Self::Digest {
        Blake3Digest(blake3::hash(bytes).into())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        Self::hash(prepare_merge(values))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        Blake3Digest(blake3::hash(Blake3Digest::digests_as_bytes(values)).into())
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&seed.0);
        hasher.update(&value.to_le_bytes());
        Blake3Digest(hasher.finalize().into())
    }
}

impl ElementHasher for Blake3_256 {
    type BaseField = Felt;

    fn hash_elements<E>(elements: &[E]) -> Self::Digest
    where
        E: FieldElement<BaseField = Self::BaseField>,
    {
        Blake3Digest(hash_elements(elements))
    }
}

impl Blake3_256 {
    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST32_BYTES> {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[Blake3Digest<DIGEST32_BYTES>; 2]) -> Blake3Digest<DIGEST32_BYTES> {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST32_BYTES>
    where
        E: FieldElement<BaseField = Felt>,
    {
        <Self as ElementHasher>::hash_elements(elements)
    }
}

// BLAKE3 192-BIT OUTPUT
// ================================================================================================

/// 192-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_192;

impl Hasher for Blake3_192 {
    /// Blake3 collision resistance is 96-bits for 24-bytes output.
    const COLLISION_RESISTANCE: u32 = 96;

    type Digest = Blake3Digest<24>;

    fn hash(bytes: &[u8]) -> Self::Digest {
        Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
        Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        Self::hash(prepare_merge(values))
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&seed.0);
        hasher.update(&value.to_le_bytes());
        Blake3Digest(*shrink_bytes(&hasher.finalize().into()))
    }
}

impl ElementHasher for Blake3_192 {
    type BaseField = Felt;

    fn hash_elements<E>(elements: &[E]) -> Self::Digest
    where
        E: FieldElement<BaseField = Self::BaseField>,
    {
        Blake3Digest(hash_elements(elements))
    }
}

impl Blake3_192 {
    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST24_BYTES> {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[Blake3Digest<DIGEST24_BYTES>; 2]) -> Blake3Digest<DIGEST24_BYTES> {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST24_BYTES>
    where
        E: FieldElement<BaseField = Felt>,
    {
        <Self as ElementHasher>::hash_elements(elements)
    }
}

// BLAKE3 160-BIT OUTPUT
// ================================================================================================

/// 160-bit output blake3 hasher.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Blake3_160;

impl Hasher for Blake3_160 {
    /// Blake3 collision resistance is 80-bits for 20-bytes output.
    const COLLISION_RESISTANCE: u32 = 80;

    type Digest = Blake3Digest<20>;

    fn hash(bytes: &[u8]) -> Self::Digest {
        Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        Self::hash(prepare_merge(values))
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
        Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&seed.0);
        hasher.update(&value.to_le_bytes());
        Blake3Digest(*shrink_bytes(&hasher.finalize().into()))
    }
}

impl ElementHasher for Blake3_160 {
    type BaseField = Felt;

    fn hash_elements<E>(elements: &[E]) -> Self::Digest
    where
        E: FieldElement<BaseField = Self::BaseField>,
    {
        Blake3Digest(hash_elements(elements))
    }
}

impl Blake3_160 {
    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> Blake3Digest<DIGEST20_BYTES> {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[Blake3Digest<DIGEST20_BYTES>; 2]) -> Blake3Digest<DIGEST20_BYTES> {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E>(elements: &[E]) -> Blake3Digest<DIGEST20_BYTES>
    where
        E: FieldElement<BaseField = Felt>,
    {
        <Self as ElementHasher>::hash_elements(elements)
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Zero-copy ref shrink to array.
fn shrink_bytes<const M: usize, const N: usize>(bytes: &[u8; M]) -> &[u8; N] {
    // compile-time assertion
    assert!(M >= N, "N should fit in M so it can be safely transmuted into a smaller slice!");
    // safety: bytes len is asserted
    unsafe { transmute(bytes) }
}

/// Hash the elements into bytes and shrink the output.
fn hash_elements<const N: usize, E>(elements: &[E]) -> [u8; N]
where
    E: FieldElement<BaseField = Felt>,
{
    // don't leak assumptions from felt and check its actual implementation.
    // this is a compile-time branch so it is for free
    let digest = if Felt::IS_CANONICAL {
        blake3::hash(E::elements_as_bytes(elements))
    } else {
        let mut hasher = blake3::Hasher::new();

        // BLAKE3 state is 64 bytes - so, we can absorb 64 bytes into the state in a single
        // permutation. we move the elements into the hasher via the buffer to give the CPU
        // a chance to process multiple element-to-byte conversions in parallel
        let mut buf = [0_u8; 64];
        let mut chunk_iter = E::slice_as_base_elements(elements).chunks_exact(8);
        for chunk in chunk_iter.by_ref() {
            for i in 0..8 {
                buf[i * 8..(i + 1) * 8].copy_from_slice(&chunk[i].as_int().to_le_bytes());
            }
            hasher.update(&buf);
        }

        for element in chunk_iter.remainder() {
            hasher.update(&element.as_int().to_le_bytes());
        }

        hasher.finalize()
    };
    *shrink_bytes(&digest.into())
}

/// Owned bytes expansion.
fn expand_bytes<const M: usize, const N: usize>(bytes: &[u8; M]) -> [u8; N] {
    // compile-time assertion
    assert!(M <= N, "M should fit in N so M can be expanded!");
    // this branch is constant so it will be optimized to be either one of the variants in release
    // mode
    if M == N {
        // safety: the sizes are checked to be the same
        unsafe { transmute_copy(bytes) }
    } else {
        let mut expanded = [0u8; N];
        expanded[..M].copy_from_slice(bytes);
        expanded
    }
}

// Cast the slice into contiguous bytes.
fn prepare_merge<const N: usize, D>(args: &[D; N]) -> &[u8]
where
    D: Deref<Target = [u8]>,
{
    // compile-time assertion
    assert!(N > 0, "N shouldn't represent an empty slice!");
    let values = args.as_ptr() as *const u8;
    let len = size_of::<D>() * N;
    // safety: the values are tested to be contiguous
    let bytes = unsafe { from_raw_parts(values, len) };
    debug_assert_eq!(args[0].deref(), &bytes[..len / N]);
    bytes
}
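// Illustrative property (not part of the original source): for the 32-byte digest,
// `merge` and `merge_many` hash the same 64 contiguous bytes and therefore agree. The
// same identity is not guaranteed for the truncated variants, whose `merge_many`
// re-expands each digest to 32 bytes via `as_bytes`.
// ```
// let a = Blake3_256::hash(b"left");
// let b = Blake3_256::hash(b"right");
// assert_eq!(Blake3_256::merge(&[a, b]), <Blake3_256 as Hasher>::merge_many(&[a, b]));
// ```
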
49
src/hash/blake/tests.rs
Normal file
@@ -0,0 +1,49 @@
use alloc::vec::Vec;

use proptest::prelude::*;
use rand_utils::rand_vector;

use super::*;

#[test]
fn blake3_hash_elements() {
    // test multiple of 8
    let elements = rand_vector::<Felt>(16);
    let expected = compute_expected_element_hash(&elements);
    let actual: [u8; 32] = hash_elements(&elements);
    assert_eq!(&expected, &actual);

    // test not multiple of 8
    let elements = rand_vector::<Felt>(17);
    let expected = compute_expected_element_hash(&elements);
    let actual: [u8; 32] = hash_elements(&elements);
    assert_eq!(&expected, &actual);
}

proptest! {
    #[test]
    fn blake160_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
        Blake3_160::hash(vec);
    }

    #[test]
    fn blake192_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
        Blake3_192::hash(vec);
    }

    #[test]
    fn blake256_wont_panic_with_arbitrary_input(ref vec in any::<Vec<u8>>()) {
        Blake3_256::hash(vec);
    }
}

// HELPER FUNCTIONS
// ================================================================================================

fn compute_expected_element_hash(elements: &[Felt]) -> blake3::Hash {
    let mut bytes = Vec::new();
    for element in elements.iter() {
        bytes.extend_from_slice(&element.as_int().to_le_bytes());
    }
    blake3::hash(&bytes)
}
19
src/hash/mod.rs
Normal file
@@ -0,0 +1,19 @@
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.

use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};

pub mod blake;

mod rescue;
pub mod rpo {
    pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
}

pub mod rpx {
    pub use super::rescue::{Rpx256, RpxDigest, RpxDigestError};
}

// RE-EXPORTS
// ================================================================================================

pub use winter_crypto::{Digest, ElementHasher, Hasher};
101
src/hash/rescue/arch/mod.rs
Normal file
@@ -0,0 +1,101 @@
#[cfg(target_feature = "sve")]
pub mod optimized {
    use crate::{hash::rescue::STATE_WIDTH, Felt};

    mod ffi {
        #[link(name = "rpo_sve", kind = "static")]
        extern "C" {
            pub fn add_constants_and_apply_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
            pub fn add_constants_and_apply_inv_sbox(
                state: *mut std::ffi::c_ulong,
                constants: *const std::ffi::c_ulong,
            ) -> bool;
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        unsafe {
            ffi::add_constants_and_apply_inv_sbox(
                state.as_mut_ptr() as *mut u64,
                ark.as_ptr() as *const u64,
            )
        }
    }
}

#[cfg(target_feature = "avx2")]
mod x86_64_avx2;

#[cfg(target_feature = "avx2")]
pub mod optimized {
    use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
    use crate::{
        hash::rescue::{add_constants, STATE_WIDTH},
        Felt,
    };

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_sbox(std::mem::transmute(state));
        }
        true
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        state: &mut [Felt; STATE_WIDTH],
        ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        add_constants(state, ark);
        unsafe {
            apply_inv_sbox(std::mem::transmute(state));
        }
        true
    }
}

#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
pub mod optimized {
    use crate::{hash::rescue::STATE_WIDTH, Felt};

    #[inline(always)]
    pub fn add_constants_and_apply_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }

    #[inline(always)]
    pub fn add_constants_and_apply_inv_sbox(
        _state: &mut [Felt; STATE_WIDTH],
        _ark: &[Felt; STATE_WIDTH],
    ) -> bool {
        false
    }
}
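// Usage sketch (illustrative, not part of the original source): callers are expected to
// branch on the returned bool and fall back to a portable implementation when no SIMD
// backend was compiled in. The x^7 computation below is a generic sketch, not the
// crate's actual scalar path.
// ```
// fn apply_round_sbox(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
//     if !optimized::add_constants_and_apply_sbox(state, ark) {
//         add_constants(state, ark);
//         for s in state.iter_mut() {
//             let s2 = *s * *s;
//             let s4 = s2 * s2;
//             *s = s4 * s2 * *s; // x^7 = x^4 * x^2 * x
//         }
//     }
// }
// ```
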
328
src/hash/rescue/arch/x86_64_avx2.rs
Normal file
@@ -0,0 +1,328 @@
use core::arch::x86_64::*;

// The following AVX2 implementation has been copied from plonky2:
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs

// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
//    emulated. The method recognizes that a + b overflowed iff (a + b) < a:
//    1. res_lo = a_lo + b_lo
//    2. carry_mask = res_lo < a_lo
//    3. res_hi = a_hi + b_hi - carry_mask
//
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
//    by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps
//    around and the comparisons are unsigned and signed respectively. The shift function
//    adds/subtracts 1 << 63 to enable this trick. Addition with carry example:
//    1. a_lo_s = shift(a_lo)
//    2. res_lo_s = a_lo_s + b_lo
//    3. carry_mask = res_lo_s <s a_lo_s
//    4. res_lo = shift(res_lo_s)
//    5. res_hi = a_hi + b_hi - carry_mask
//
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition
// is shifted if exactly one of the operands is shifted, as is the case on
// line 2. Line 3. performs a signed comparison res_lo_s <s a_lo_s on shifted values to
// emulate unsigned comparison res_lo <u a_lo on unshifted values. Finally, line 4. reverses the
// shift so the result can be returned.
//
// When performing a chain of calculations, we can often save instructions by letting
// the shift propagate through and only undoing it when necessary.
// For example, to compute the addition of three two-word (128-bit) numbers we can do:
// 1. a_lo_s = shift(a_lo)
// 2. tmp_lo_s = a_lo_s + b_lo
// 3. tmp_carry_mask = tmp_lo_s <s a_lo_s
// 4. tmp_hi = a_hi + b_hi - tmp_carry_mask
// 5. res_lo_s = tmp_lo_s + c_lo
// 6. res_carry_mask = res_lo_s <s tmp_lo_s
// 7. res_lo = shift(res_lo_s)
// 8. res_hi = tmp_hi + c_hi - res_carry_mask
//
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.

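// Scalar sketch (illustrative, not part of the original source) of the carry-mask trick
// from note 1, written for a single u64 lane:
// ```
// fn add_128(a_lo: u64, a_hi: u64, b_lo: u64, b_hi: u64) -> (u64, u64) {
//     let res_lo = a_lo.wrapping_add(b_lo);
//     // the comparison plays the role of the AVX2 all-ones mask (-1 for true)
//     let carry_mask = if res_lo < a_lo { u64::MAX } else { 0 };
//     let res_hi = a_hi.wrapping_add(b_hi).wrapping_sub(carry_mask);
//     (res_lo, res_hi)
// }
// ```
// Subtracting the all-ones mask is the same as adding 1, mirroring how the vector code
// subtracts the result of `_mm256_cmpgt_epi32`.
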
#[inline(always)]
|
||||
pub fn branch_hint() {
|
||||
// NOTE: These are the currently supported assembly architectures. See the
|
||||
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
|
||||
// the most up-to-date list.
|
||||
#[cfg(any(
|
||||
target_arch = "aarch64",
|
||||
target_arch = "arm",
|
||||
target_arch = "riscv32",
|
||||
target_arch = "riscv64",
|
||||
target_arch = "x86",
|
||||
target_arch = "x86_64",
|
||||
))]
|
||||
unsafe {
|
||||
core::arch::asm!("", options(nomem, nostack, preserves_flags));
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! map3 {
|
||||
($f:ident:: < $l:literal > , $v:ident) => {
|
||||
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
||||
};
|
||||
($f:ident:: < $l:literal > , $v1:ident, $v2:ident) => {
|
||||
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
|
||||
};
|
||||
($f:ident, $v:ident) => {
|
||||
($f($v.0), $f($v.1), $f($v.2))
|
||||
};
|
||||
($f:ident, $v0:ident, $v1:ident) => {
|
||||
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
||||
};
|
||||
($f:ident,rep $v0:ident, $v1:ident) => {
|
||||
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
|
||||
};
|
||||
|
||||
($f:ident, $v0:ident,rep $v1:ident) => {
|
||||
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
|
||||
};
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn square3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
|
||||
// All pairwise multiplications.
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
||||
|
||||
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
|
||||
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
|
||||
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
|
||||
// position).
|
||||
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
|
||||
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mul3(
|
||||
x: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||
let x_hi = {
|
||||
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||
// This is safe and free.
|
||||
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||
map3!(_mm256_castps_si256, x_hi_ps)
|
||||
};
|
||||
let y_hi = {
|
||||
let y_ps = map3!(_mm256_castsi256_ps, y);
|
||||
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
|
||||
map3!(_mm256_castps_si256, y_hi_ps)
|
||||
};
|
||||
|
||||
// All four pairwise multiplications
|
||||
let mul_ll = map3!(_mm256_mul_epu32, x, y);
|
||||
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
|
||||
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
|
||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
|
||||
|
||||
// Bignum addition
|
||||
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
|
||||
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
|
||||
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
|
||||
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
|
||||
// Also, extract high 32 bits of t0 and add to mul_hh.
|
||||
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
|
||||
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
|
||||
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
|
||||
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||
// Lastly, extract the high 32 bits of t1 and add to t2.
|
||||
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
|
||||
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
|
||||
|
||||
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
|
||||
// position).
|
||||
let t1_lo = {
|
||||
let t1_ps = map3!(_mm256_castsi256_ps, t1);
|
||||
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
|
||||
map3!(_mm256_castps_si256, t1_lo_ps)
|
||||
};
|
||||
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
|
||||
|
||||
(res_lo, res_hi)
|
||||
}
|
||||
|
||||
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
|
||||
#[inline(always)]
|
||||
unsafe fn add_small(
|
||||
x_s: (__m256i, __m256i, __m256i),
|
||||
y: (__m256i, __m256i, __m256i),
|
||||
) -> (__m256i, __m256i, __m256i) {
|
||||
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
|
||||
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
|
||||
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
|
||||
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
|
||||
res_s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
|
||||
// The subtraction is very unlikely to overflow so we're best off branching.
|
||||
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
|
||||
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
|
||||
// floating-point (this is free).
|
||||
let mask_pd = _mm256_castsi256_pd(mask);
|
||||
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
|
||||
// did not occur for any of the vector elements.
|
||||
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
|
||||
res_wrapped_s
|
||||
} else {
|
||||
branch_hint();
|
||||
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
|
||||
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
|
||||
_mm256_sub_epi64(res_wrapped_s, adj_amount)
|
||||
}
|
||||
}
|
||||
|
||||
/// Addition, where the second operand is much smaller than `0xffffffff00000001`.
|
||||
#[inline(always)]
unsafe fn sub_tiny(
    x_s: (__m256i, __m256i, __m256i),
    y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
    let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
    let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
    res_s
}

#[inline(always)]
unsafe fn reduce3(
    (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
) -> (__m256i, __m256i, __m256i) {
    let sign_bit = _mm256_set1_epi64x(i64::MIN);
    let epsilon = _mm256_set1_epi64x(0xffffffff);
    let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
    let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
    let lo1_s = sub_tiny(lo0_s, hi_hi0);
    let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
    let lo2_s = add_small(lo1_s, t1);
    let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
    lo2
}

#[inline(always)]
unsafe fn mul_reduce(
    a: (__m256i, __m256i, __m256i),
    b: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
    reduce3(mul3(a, b))
}

#[inline(always)]
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    reduce3(square3(state))
}

#[inline(always)]
unsafe fn exp_acc(
    high: (__m256i, __m256i, __m256i),
    low: (__m256i, __m256i, __m256i),
    exp: usize,
) -> (__m256i, __m256i, __m256i) {
    let mut result = high;
    for _ in 0..exp {
        result = square_reduce(result);
    }
    mul_reduce(result, low)
}

#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    let state2 = square_reduce(state);
    let state4_unreduced = square3(state2);
    let state3_unreduced = mul3(state2, state);
    let state4 = reduce3(state4_unreduced);
    let state3 = reduce3(state3_unreduced);
    let state7_unreduced = mul3(state3, state4);
    let state7 = reduce3(state7_unreduced);
    state7
}

#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let t1 = square_reduce(state);

    // compute base^100
    let t2 = square_reduce(t1);

    // compute base^100100
    let t3 = exp_acc(t2, t2, 3);

    // compute base^100100100100
    let t4 = exp_acc(t3, t3, 6);

    // compute base^100100100100100100100100
    let t5 = exp_acc(t4, t4, 12);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc(t5, t3, 6);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc(t6, t6, 31);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
    let b = mul_reduce(t1, mul_reduce(t2, state));
    mul_reduce(a, b)
}

#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
    (
        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
    )
}

#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
    _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
    _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
    _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}

#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_sbox(state);
    avx2_store(buffer, state);
}

#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
    let mut state = avx2_load(&buffer);
    state = do_apply_inv_sbox(state);
    avx2_store(buffer, state);
}
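An aside, not part of the diff: the overflow correction in `add_small` above is easier to follow in scalar form. The sketch below is a minimal standalone model of it, assuming `y < p`; the vector code reaches the same decision through a signed comparison on sign-flipped operands (the `_s` convention), since AVX2 has no unsigned 64-bit compare, while a scalar version can simply use `overflowing_add`.

// Scalar model of add_small: if x + y wraps past 2^64, add EPSILON = 2^32 - 1,
// because 2^64 is congruent to 2^32 - 1 modulo p = 2^64 - 2^32 + 1.
const EPSILON: u64 = 0xffff_ffff;

fn add_small_scalar(x: u64, y: u64) -> u64 {
    let (res_wrapped, overflowed) = x.overflowing_add(y);
    // Assuming y < p, res_wrapped + EPSILON cannot itself overflow.
    res_wrapped + EPSILON * (overflowed as u64)
}

fn main() {
    let p: u128 = 0xffff_ffff_0000_0001;
    let (x, y) = (u64::MAX - 5, 17_u64);
    assert_eq!(add_small_scalar(x, y) as u128 % p, (x as u128 + y as u128) % p);
}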
197
src/hash/rescue/mds/freq.rs
Normal file
@@ -0,0 +1,197 @@
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================

//! This module contains helper functions as well as constants used to perform the vector-matrix
//! multiplication step of the Rescue Prime permutation. The special form of our MDS matrix,
//! i.e., being circulant, allows us to reduce the vector-matrix multiplication to a Hadamard
//! product of two vectors in "frequency domain". This follows from the simple fact that every
//! circulant matrix has the columns of the discrete Fourier transform matrix as orthogonal
//! eigenvectors.
//! The implementation also avoids the use of 3-point FFTs and 3-point iFFTs, and substitutes them
//! with explicit expressions. It also avoids, due to the form of our matrix in the frequency
//! domain, divisions by 2 and repeated modular reductions. This is because of our explicit choice
//! of an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
//! implementation.

// Rescue MDS matrix in frequency domain.
//
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix, i.e., just before the multiplication with the appropriate twiddle factors
// and application of the final four 3-point FFTs in order to get the full 12-point FFT.
// The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4.
// The code to generate the matrix in frequency domain is based on an adaptation of code, developed
// by the Polygon Zero team, for generating MDS matrices efficiently in the original domain.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];

// We use a split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
#[inline(always)]
pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
    let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;

    let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
    let (u4, u5, u6) = fft4_real([s1, s4, s7, s10]);
    let (u8, u9, u10) = fft4_real([s2, s5, s8, s11]);

    // This is where the multiplication in frequency domain is done. More precisely, and with
    // the appropriate permutations in between, the sequence of
    // 3-point FFTs --> multiplication by twiddle factors --> Hadamard multiplication -->
    // 3-point iFFTs --> multiplication by (inverse) twiddle factors
    // is "squashed" into one step composed of the functions "block1", "block2" and "block3".
    // The expressions in the aforementioned functions are the result of explicit computations
    // combined with the Karatsuba trick for the multiplication of complex numbers.

    let [v0, v4, v8] = block1([u0, u4, u8], MDS_FREQ_BLOCK_ONE);
    let [v1, v5, v9] = block2([u1, u5, u9], MDS_FREQ_BLOCK_TWO);
    let [v2, v6, v10] = block3([u2, u6, u10], MDS_FREQ_BLOCK_THREE);
    // The 4th block is not computed as it is similar to the 2nd one, up to complex conjugation,
    // and is, due to the use of the real FFT and iFFT, redundant.

    let [s0, s3, s6, s9] = ifft4_real((v0, v1, v2));
    let [s1, s4, s7, s10] = ifft4_real((v4, v5, v6));
    let [s2, s5, s8, s11] = ifft4_real((v8, v9, v10));

    [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
}

// We use the real FFT to avoid redundant computations. See https://www.mdpi.com/2076-3417/12/9/4700
#[inline(always)]
const fn fft2_real(x: [u64; 2]) -> [i64; 2] {
    [(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)]
}

#[inline(always)]
const fn ifft2_real(y: [i64; 2]) -> [u64; 2] {
    // We avoid divisions by 2 by appropriately scaling the MDS matrix constants.
    [(y[0] + y[1]) as u64, (y[0] - y[1]) as u64]
}

#[inline(always)]
const fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) {
    let [z0, z2] = fft2_real([x[0], x[2]]);
    let [z1, z3] = fft2_real([x[1], x[3]]);
    let y0 = z0 + z1;
    let y1 = (z2, -z3);
    let y2 = z0 - z1;
    (y0, y1, y2)
}

#[inline(always)]
const fn ifft4_real(y: (i64, (i64, i64), i64)) -> [u64; 4] {
    // In calculating 'z0' and 'z1', division by 2 is avoided by appropriately scaling
    // the MDS matrix constants.
    let z0 = y.0 + y.2;
    let z1 = y.0 - y.2;
    let z2 = y.1 .0;
    let z3 = -y.1 .1;

    let [x0, x2] = ifft2_real([z0, z2]);
    let [x1, x3] = ifft2_real([z1, z3]);

    [x0, x1, x2, x3]
}

#[inline(always)]
const fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    let [x0, x1, x2] = x;
    let [y0, y1, y2] = y;
    let z0 = x0 * y0 + x1 * y2 + x2 * y1;
    let z1 = x0 * y1 + x1 * y0 + x2 * y2;
    let z2 = x0 * y2 + x1 * y1 + x2 * y0;

    [z0, z1, z2]
}

#[inline(always)]
const fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] {
    let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x;
    let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y;
    let x0s = x0r + x0i;
    let x1s = x1r + x1i;
    let x2s = x2r + x2i;
    let y0s = y0r + y0i;
    let y1s = y1r + y1i;
    let y2s = y2r + y2i;

    // Compute x0y0 − ix1y2 − ix2y1 using Karatsuba for complex numbers multiplication
    let m0 = (x0r * y0r, x0i * y0i);
    let m1 = (x1r * y2r, x1i * y2i);
    let m2 = (x2r * y1r, x2i * y1i);
    let z0r = (m0.0 - m0.1) + (x1s * y2s - m1.0 - m1.1) + (x2s * y1s - m2.0 - m2.1);
    let z0i = (x0s * y0s - m0.0 - m0.1) + (-m1.0 + m1.1) + (-m2.0 + m2.1);
    let z0 = (z0r, z0i);

    // Compute x0y1 + x1y0 − ix2y2 using Karatsuba for complex numbers multiplication
    let m0 = (x0r * y1r, x0i * y1i);
    let m1 = (x1r * y0r, x1i * y0i);
    let m2 = (x2r * y2r, x2i * y2i);
    let z1r = (m0.0 - m0.1) + (m1.0 - m1.1) + (x2s * y2s - m2.0 - m2.1);
    let z1i = (x0s * y1s - m0.0 - m0.1) + (x1s * y0s - m1.0 - m1.1) + (-m2.0 + m2.1);
    let z1 = (z1r, z1i);

    // Compute x0y2 + x1y1 + x2y0 using Karatsuba for complex numbers multiplication
    let m0 = (x0r * y2r, x0i * y2i);
    let m1 = (x1r * y1r, x1i * y1i);
    let m2 = (x2r * y0r, x2i * y0i);
    let z2r = (m0.0 - m0.1) + (m1.0 - m1.1) + (m2.0 - m2.1);
    let z2i = (x0s * y2s - m0.0 - m0.1) + (x1s * y1s - m1.0 - m1.1) + (x2s * y0s - m2.0 - m2.1);
    let z2 = (z2r, z2i);

    [z0, z1, z2]
}

#[inline(always)]
const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    let [x0, x1, x2] = x;
    let [y0, y1, y2] = y;
    let z0 = x0 * y0 - x1 * y2 - x2 * y1;
    let z1 = x0 * y1 + x1 * y0 - x2 * y2;
    let z2 = x0 * y2 + x1 * y1 + x2 * y0;

    [z0, z1, z2]
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::super::{apply_mds, Felt, MDS, ZERO};

    const STATE_WIDTH: usize = 12;

    #[inline(always)]
    fn apply_mds_naive(state: &mut [Felt; STATE_WIDTH]) {
        let mut result = [ZERO; STATE_WIDTH];
        result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
            state.iter().zip(mds_row).for_each(|(&s, m)| {
                *r += m * s;
            });
        });
        *state = result;
    }

    proptest! {
        #[test]
        fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

            let mut v1 = [ZERO; STATE_WIDTH];
            let mut v2;

            for i in 0..STATE_WIDTH {
                v1[i] = Felt::new(a[i]);
            }
            v2 = v1;

            apply_mds_naive(&mut v1);
            apply_mds(&mut v2);

            prop_assert_eq!(v1, v2);
        }
    }
}
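An aside, not part of the diff: each of `block1`/`block2`/`block3` above is a 3-point cyclic convolution of frequency-domain blocks (`block2` and `block3` carry extra sign twists from the twiddle factors). For `block1` the identity is plain, z[k] = Σ_i x[i] · y[(k − i) mod 3], which the standalone brute-force check below confirms:

// block1 is a 3-point cyclic convolution; verify against the naive definition.
fn cyclic_conv3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    let mut z = [0_i64; 3];
    for k in 0..3 {
        for i in 0..3 {
            z[k] += x[i] * y[(k + 3 - i) % 3];
        }
    }
    z
}

fn block1_ref(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
    [
        x[0] * y[0] + x[1] * y[2] + x[2] * y[1],
        x[0] * y[1] + x[1] * y[0] + x[2] * y[2],
        x[0] * y[2] + x[1] * y[1] + x[2] * y[0],
    ]
}

fn main() {
    let x = [3, -1, 7];
    let y = [16, 8, 16]; // MDS_FREQ_BLOCK_ONE
    assert_eq!(cyclic_conv3(x, y), block1_ref(x, y));
}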
214
src/hash/rescue/mds/mod.rs
Normal file
@@ -0,0 +1,214 @@
use super::{Felt, STATE_WIDTH, ZERO};

mod freq;
pub use freq::mds_multiply_freq;

// MDS MULTIPLICATION
// ================================================================================================

#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
    let mut result = [ZERO; STATE_WIDTH];

    // Using the linearity of the operations we can split the state into a low||high decomposition
    // and operate on each with no overflow and then combine/reduce the result to a field element.
    // The absence of overflow is guaranteed by the fact that the MDS matrix entries are small
    // powers of two in the frequency domain.
    let mut state_l = [0u64; STATE_WIDTH];
    let mut state_h = [0u64; STATE_WIDTH];

    for r in 0..STATE_WIDTH {
        let s = state[r].inner();
        state_h[r] = s >> 32;
        state_l[r] = (s as u32) as u64;
    }

    let state_h = mds_multiply_freq(state_h);
    let state_l = mds_multiply_freq(state_l);

    for r in 0..STATE_WIDTH {
        let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
        let s_hi = (s >> 64) as u64;
        let s_lo = s as u64;
        let z = (s_hi << 32) - s_hi;
        let (res, over) = s_lo.overflowing_add(z);

        result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
    }
    *state = result;
}

// MDS MATRIX
// ================================================================================================

/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
    [
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
    ],
    [
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
    ],
    [
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
    ],
    [
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
    ],
    [
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
    ],
    [
        Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8),
        Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10),
    ],
    [
        Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21),
        Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13),
    ],
    [
        Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6), Felt::new(22),
        Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8), Felt::new(26),
    ],
    [
        Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7), Felt::new(6),
        Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23), Felt::new(8),
    ],
    [
        Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9), Felt::new(7),
        Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7), Felt::new(23),
    ],
    [
        Felt::new(23), Felt::new(8), Felt::new(26), Felt::new(13), Felt::new(10), Felt::new(9),
        Felt::new(7), Felt::new(6), Felt::new(22), Felt::new(21), Felt::new(8), Felt::new(7),
    ],
];
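An aside, not part of the diff: the combine loop in `apply_mds` above folds a roughly 96-bit value back toward the field using 2^64 ≡ 2^32 − 1 (mod p), which is why `z = (s_hi << 32) - s_hi` appears. A standalone sketch of just that reduction step:

// Reduce s = s_hi * 2^64 + s_lo modulo p = 2^64 - 2^32 + 1, assuming s < 2^96,
// so that s_hi < 2^32 and (s_hi << 32) - s_hi cannot overflow a u64.
const P: u128 = 0xffff_ffff_0000_0001;

fn reduce_u96(s: u128) -> u64 {
    let s_hi = (s >> 64) as u64;
    let s_lo = s as u64;
    let z = (s_hi << 32) - s_hi; // s_hi * (2^32 - 1), i.e. s_hi * 2^64 mod p
    let (res, over) = s_lo.overflowing_add(z);
    // On overflow, add 2^32 - 1 once more (again because 2^64 ≡ 2^32 - 1 mod p).
    res.wrapping_add(0u32.wrapping_sub(over as u32) as u64)
}

fn main() {
    let s: u128 = (123_u128 << 64) + 0xffff_ffff_dead_beef;
    assert_eq!(reduce_u96(s) as u128 % P, s % P);
}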
347
src/hash/rescue/mod.rs
Normal file
@@ -0,0 +1,347 @@
use core::ops::Range;

use super::{CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ZERO};

mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};

mod mds;
use mds::{apply_mds, MDS};

mod rpo;
pub use rpo::{Rpo256, RpoDigest, RpoDigestError};

mod rpx;
pub use rpx::{Rpx256, RpxDigest, RpxDigestError};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

/// The number of rounds is set to 7. For the RPO hash function, all rounds are uniform. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;

/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;

/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;

const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;

/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;

/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;

/// The number of bytes needed to encode a digest.
const DIGEST_BYTES: usize = 32;

/// The number of byte chunks defining a field element when hashing a sequence of bytes.
const BINARY_CHUNK_SIZE: usize = 7;

/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;

// SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
    state[0] = state[0].exp7();
    state[1] = state[1].exp7();
    state[2] = state[2].exp7();
    state[3] = state[3].exp7();
    state[4] = state[4].exp7();
    state[5] = state[5].exp7();
    state[6] = state[6].exp7();
    state[7] = state[7].exp7();
    state[8] = state[8].exp7();
    state[9] = state[9].exp7();
    state[10] = state[10].exp7();
    state[11] = state[11].exp7();
}

// INVERSE SBOX FUNCTION
// ================================================================================================

#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
    // compute base^10540996611094048183 using 72 multiplications per array element
    // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111

    // compute base^10
    let mut t1 = *state;
    t1.iter_mut().for_each(|t| *t = t.square());

    // compute base^100
    let mut t2 = t1;
    t2.iter_mut().for_each(|t| *t = t.square());

    // compute base^100100
    let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);

    // compute base^100100100100
    let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);

    // compute base^100100100100100100100100
    let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);

    // compute base^100100100100100100100100100100
    let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);

    // compute base^1001001001001001001001001001000100100100100100100100100100100
    let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);

    // compute base^1001001001001001001001001001000110110110110110110110110110110111
    for (i, s) in state.iter_mut().enumerate() {
        let a = (t7[i].square() * t6[i]).square().square();
        let b = t1[i] * t2[i] * *s;
        *s = a * b;
    }

    #[inline(always)]
    fn exp_acc<B: StarkField, const N: usize, const M: usize>(
        base: [B; N],
        tail: [B; N],
    ) -> [B; N] {
        let mut result = base;
        for _ in 0..M {
            result.iter_mut().for_each(|r| *r = r.square());
        }
        result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
        result
    }
}

#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
    state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}

// ROUND CONSTANTS
// ================================================================================================

/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of an RPO round, and ARK2 contains the constants for the second half of an RPO
/// round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(5789762306288267392), Felt::new(6522564764413701783),
        Felt::new(17809893479458208203), Felt::new(107145243989736508),
        Felt::new(6388978042437517382), Felt::new(15844067734406016715),
        Felt::new(9975000513555218239), Felt::new(3344984123768313364),
        Felt::new(9959189626657347191), Felt::new(12960773468763563665),
        Felt::new(9602914297752488475), Felt::new(16657542370200465908),
    ],
    [
        Felt::new(12987190162843096997), Felt::new(653957632802705281),
        Felt::new(4441654670647621225), Felt::new(4038207883745915761),
        Felt::new(5613464648874830118), Felt::new(13222989726778338773),
        Felt::new(3037761201230264149), Felt::new(16683759727265180203),
        Felt::new(8337364536491240715), Felt::new(3227397518293416448),
        Felt::new(8110510111539674682), Felt::new(2872078294163232137),
    ],
    [
        Felt::new(18072785500942327487), Felt::new(6200974112677013481),
        Felt::new(17682092219085884187), Felt::new(10599526828986756440),
        Felt::new(975003873302957338), Felt::new(8264241093196931281),
        Felt::new(10065763900435475170), Felt::new(2181131744534710197),
        Felt::new(6317303992309418647), Felt::new(1401440938888741532),
        Felt::new(8884468225181997494), Felt::new(13066900325715521532),
    ],
    [
        Felt::new(5674685213610121970), Felt::new(5759084860419474071),
        Felt::new(13943282657648897737), Felt::new(1352748651966375394),
        Felt::new(17110913224029905221), Felt::new(1003883795902368422),
        Felt::new(4141870621881018291), Felt::new(8121410972417424656),
        Felt::new(14300518605864919529), Felt::new(13712227150607670181),
        Felt::new(17021852944633065291), Felt::new(6252096473787587650),
    ],
    [
        Felt::new(4887609836208846458), Felt::new(3027115137917284492),
        Felt::new(9595098600469470675), Felt::new(10528569829048484079),
        Felt::new(7864689113198939815), Felt::new(17533723827845969040),
        Felt::new(5781638039037710951), Felt::new(17024078752430719006),
        Felt::new(109659393484013511), Felt::new(7158933660534805869),
        Felt::new(2955076958026921730), Felt::new(7433723648458773977),
    ],
    [
        Felt::new(16308865189192447297), Felt::new(11977192855656444890),
        Felt::new(12532242556065780287), Felt::new(14594890931430968898),
        Felt::new(7291784239689209784), Felt::new(5514718540551361949),
        Felt::new(10025733853830934803), Felt::new(7293794580341021693),
        Felt::new(6728552937464861756), Felt::new(6332385040983343262),
        Felt::new(13277683694236792804), Felt::new(2600778905124452676),
    ],
    [
        Felt::new(7123075680859040534), Felt::new(1034205548717903090),
        Felt::new(7717824418247931797), Felt::new(3019070937878604058),
        Felt::new(11403792746066867460), Felt::new(10280580802233112374),
        Felt::new(337153209462421218), Felt::new(13333398568519923717),
        Felt::new(3596153696935337464), Felt::new(8104208463525993784),
        Felt::new(14345062289456085693), Felt::new(17036731477169661256),
    ],
];

const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
    [
        Felt::new(6077062762357204287), Felt::new(15277620170502011191),
        Felt::new(5358738125714196705), Felt::new(14233283787297595718),
        Felt::new(13792579614346651365), Felt::new(11614812331536767105),
        Felt::new(14871063686742261166), Felt::new(10148237148793043499),
        Felt::new(4457428952329675767), Felt::new(15590786458219172475),
        Felt::new(10063319113072092615), Felt::new(14200078843431360086),
    ],
    [
        Felt::new(6202948458916099932), Felt::new(17690140365333231091),
        Felt::new(3595001575307484651), Felt::new(373995945117666487),
        Felt::new(1235734395091296013), Felt::new(14172757457833931602),
        Felt::new(707573103686350224), Felt::new(15453217512188187135),
        Felt::new(219777875004506018), Felt::new(17876696346199469008),
        Felt::new(17731621626449383378), Felt::new(2897136237748376248),
    ],
    [
        Felt::new(8023374565629191455), Felt::new(15013690343205953430),
        Felt::new(4485500052507912973), Felt::new(12489737547229155153),
        Felt::new(9500452585969030576), Felt::new(2054001340201038870),
        Felt::new(12420704059284934186), Felt::new(355990932618543755),
        Felt::new(9071225051243523860), Felt::new(12766199826003448536),
        Felt::new(9045979173463556963), Felt::new(12934431667190679898),
    ],
    [
        Felt::new(18389244934624494276), Felt::new(16731736864863925227),
        Felt::new(4440209734760478192), Felt::new(17208448209698888938),
        Felt::new(8739495587021565984), Felt::new(17000774922218161967),
        Felt::new(13533282547195532087), Felt::new(525402848358706231),
        Felt::new(16987541523062161972), Felt::new(5466806524462797102),
        Felt::new(14512769585918244983), Felt::new(10973956031244051118),
    ],
    [
        Felt::new(6982293561042362913), Felt::new(14065426295947720331),
        Felt::new(16451845770444974180), Felt::new(7139138592091306727),
        Felt::new(9012006439959783127), Felt::new(14619614108529063361),
        Felt::new(1394813199588124371), Felt::new(4635111139507788575),
        Felt::new(16217473952264203365), Felt::new(10782018226466330683),
        Felt::new(6844229992533662050), Felt::new(7446486531695178711),
    ],
    [
        Felt::new(3736792340494631448), Felt::new(577852220195055341),
        Felt::new(6689998335515779805), Felt::new(13886063479078013492),
        Felt::new(14358505101923202168), Felt::new(7744142531772274164),
        Felt::new(16135070735728404443), Felt::new(12290902521256031137),
        Felt::new(12059913662657709804), Felt::new(16456018495793751911),
        Felt::new(4571485474751953524), Felt::new(17200392109565783176),
    ],
    [
        Felt::new(17130398059294018733), Felt::new(519782857322261988),
        Felt::new(9625384390925085478), Felt::new(1664893052631119222),
        Felt::new(7629576092524553570), Felt::new(3485239601103661425),
        Felt::new(9755891797164033838), Felt::new(15218148195153269027),
        Felt::new(16460604813734957368), Felt::new(9643968136937729763),
        Felt::new(3611348709641382851), Felt::new(18256379591337759196),
    ],
];
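An aside, not part of the diff: the unrolled addition chain in `apply_inv_sbox` computes x^INV_ALPHA, and inverting the S-box x -> x^7 requires 7 · INV_ALPHA ≡ 1 (mod p − 1) by Fermat's little theorem. That arithmetic can be checked in isolation:

fn main() {
    // p - 1 = 2^64 - 2^32 for the field with modulus p = 2^64 - 2^32 + 1.
    let p_minus_1: u128 = 0xffff_ffff_0000_0000;
    let alpha: u128 = 7;
    let inv_alpha: u128 = 10540996611094048183;
    assert_eq!((alpha * inv_alpha) % p_minus_1, 1);
}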
646
src/hash/rescue/rpo/digest.rs
Normal file
@@ -0,0 +1,646 @@
use alloc::string::String;
use core::{
    cmp::Ordering,
    fmt::Display,
    hash::{Hash, Hasher},
    ops::Deref,
    slice,
};

use thiserror::Error;

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
    rand::Randomizable,
    utils::{
        bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
        DeserializationError, HexParseError, Serializable,
    },
};

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);

impl RpoDigest {
    /// The serialized size of the digest in bytes.
    pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;

    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }

    pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
        let p = digests.as_ptr();
        let len = digests.len() * DIGEST_SIZE;
        unsafe { slice::from_raw_parts(p as *const Felt, len) }
    }

    /// Returns hexadecimal representation of this digest prefixed with `0x`.
    pub fn to_hex(&self) -> String {
        bytes_to_hex_string(self.as_bytes())
    }
}

impl Hash for RpoDigest {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write(&self.as_bytes());
    }
}

impl Digest for RpoDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpoDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpoDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // it will iterate the elements and will return the first comparison different from
        // `Equal`. Otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash computation, the digest shouldn't have any ordered property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpoDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpoDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpoDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================

#[derive(Debug, Error)]
pub enum RpoDigestError {
    #[error("failed to convert digest field element to {0}")]
    TypeConversion(&'static str),
    #[error("failed to convert to field element: {0}")]
    InvalidFieldElement(String),
}

impl TryFrom<&RpoDigest> for [bool; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [bool; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        fn to_bool(v: u64) -> Option<bool> {
            if v <= 1 {
                Some(v == 1)
            } else {
                None
            }
        }

        Ok([
            to_bool(value.0[0].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[1].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[2].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
            to_bool(value.0[3].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u8; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u8; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u8"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u16; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u16; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u16"))?,
        ])
    }
}

impl TryFrom<&RpoDigest> for [u32; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpoDigest> for [u32; DIGEST_SIZE] {
    type Error = RpoDigestError;

    fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpoDigestError::TypeConversion("u32"))?,
        ])
    }
}

impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

impl From<RpoDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpoDigest) -> Self {
        value.0
    }
}

impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

impl From<RpoDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpoDigest) -> Self {
        value.as_bytes()
    }
}

impl From<&RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpoDigest) -> Self {
        (*value).into()
    }
}

impl From<RpoDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpoDigest) -> Self {
        value.to_hex()
    }
}

// CONVERSIONS: TO RPO DIGEST
// ================================================================================================

impl From<&[bool; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[bool; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[bool; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [bool; DIGEST_SIZE]) -> Self {
        [value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
    }
}

impl From<&[u8; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u8; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u8; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u8; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u16; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u16; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u16; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u16; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u32; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[u32; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u32; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [u32; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
    type Error = RpoDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
        Ok(Self([
            value[0].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[1].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[2].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
            value[3].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
        ]))
    }
}

impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, the conversion from slice to array must succeed so the
        // `unwrap`s below are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8]> for RpoDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpoDigest::try_from)
    }
}

impl TryFrom<String> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<&String> for RpoDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpoDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }

    fn get_size_hint(&self) -> usize {
        Self::SERIALIZED_SIZE
    }
}

impl Deserializable for RpoDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// ITERATORS
// ================================================================================================
impl IntoIterator for RpoDigest {
    type Item = Felt;
    type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use alloc::string::String;

    use rand_utils::rand_value;

    use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::SliceReader;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpoDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());
        assert_eq!(bytes.len(), d1.get_size_hint());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpoDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }

    #[test]
    fn digest_encoding() {
        let digest = RpoDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let string: String = digest.into();
        let round_trip: RpoDigest = string.try_into().expect("decoding failed");

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpoDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        // BY VALUE
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpoDigest = v.into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpoDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpoDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        // BY REF
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpoDigest = (&v).into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpoDigest = (&v).into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpoDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
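An aside, not part of the diff: the fold in `Ord for RpoDigest` above implements a plain lexicographic comparison over the four inner limbs; the first non-`Equal` limb comparison wins. A minimal standalone illustration of that pattern:

use core::cmp::Ordering;

fn lex_cmp4(a: [u64; 4], b: [u64; 4]) -> Ordering {
    a.iter().zip(b.iter()).fold(Ordering::Equal, |ord, (x, y)| match ord {
        Ordering::Equal => x.cmp(y),
        _ => ord,
    })
}

fn main() {
    assert_eq!(lex_cmp4([1, 2, 3, 4], [1, 2, 3, 4]), Ordering::Equal);
    assert_eq!(lex_cmp4([1, 2, 3, 4], [1, 2, 4, 0]), Ordering::Less);
    assert_eq!(lex_cmp4([5, 0, 0, 0], [1, 9, 9, 9]), Ordering::Greater);
}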
339
src/hash/rescue/rpo/mod.rs
Normal file
@@ -0,0 +1,339 @@
use core::ops::Range;
|
||||
|
||||
use super::{
|
||||
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
|
||||
apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
|
||||
ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
|
||||
INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
|
||||
};
|
||||
|
||||
mod digest;
|
||||
pub use digest::{RpoDigest, RpoDigestError};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// HASHER IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
|
||||
///
|
||||
/// The hash function is implemented according to the Rescue Prime Optimized
|
||||
/// [specifications](https://eprint.iacr.org/2022/1577) while the padding rule follows the one
|
||||
/// described [here](https://eprint.iacr.org/2023/1045).
|
||||
///
|
||||
/// The parameters used to instantiate the function are:
|
||||
/// * Field: 64-bit prime field with modulus p = 2^64 - 2^32 + 1.
|
||||
/// * State width: 12 field elements.
|
||||
/// * Rate size: r = 8 field elements.
|
||||
/// * Capacity size: c = 4 field elements.
|
||||
/// * Number of founds: 7.
|
||||
/// * S-Box degree: 7.
|
||||
///
|
||||
/// The above parameters target a 128-bit security level. The digest consists of four field elements
|
||||
/// and it can be serialized into 32 bytes (256 bits).
|
||||
///
|
||||
/// ## Hash output consistency
|
||||
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
|
||||
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
|
||||
/// a hash for the same set of elements using these functions will always produce the same
|
||||
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
|
||||
/// same result as hashing 8 elements which make up these digests using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function.
|
||||
///
|
||||
/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above.
|
||||
/// For example, if we take two field elements, serialize them to bytes and hash them using
|
||||
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
|
||||
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
|
||||
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
|
||||
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
|
||||
/// deserialization procedure used by this function is different from the procedure used to
|
||||
/// deserialize valid field elements.
|
||||
///
|
||||
/// Thus, if the underlying data consists of valid field elements, it might make more sense
|
||||
/// to deserialize them into field elements and then hash them using
|
||||
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
|
||||
/// using [hash()](Rpo256::hash) function.
|
||||
///
|
||||
/// ## Domain separation
|
||||
/// [merge_in_domain()](Rpo256::merge_in_domain) hashes two digests into one digest with some domain
|
||||
/// identifier and the current implementation sets the second capacity element to the value of
|
||||
/// this domain identifier. Using a similar argument to the one formulated for domain separation of
|
||||
/// the RPX hash function in Appendix C of its [specification](https://eprint.iacr.org/2023/1045),
|
||||
/// one sees that doing so degrades only pre-image resistance, from its initial bound of c.log_2(p),
|
||||
/// by as much as the log_2 of the size of the domain identifier space. Since pre-image resistance
|
||||
/// becomes the bottleneck for the security bound of the sponge in overwrite-mode only when it is
|
||||
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
|
||||
/// the size of the domain identifier space, including for padding, is less than 2^128.
|
||||
///
|
||||
/// ## Hashing of empty input
|
||||
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
|
||||
/// the benefit of requiring no calls to the RPO permutation when hashing empty input.
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
pub struct Rpo256();
|
||||
|
||||
impl Hasher for Rpo256 {
|
||||
/// Rpo256 collision resistance is 128-bits.
|
||||
const COLLISION_RESISTANCE: u32 = 128;
|
||||
|
||||
type Digest = RpoDigest;
|
||||
|
||||
fn hash(bytes: &[u8]) -> Self::Digest {
|
||||
// initialize the state with zeroes
|
||||
let mut state = [ZERO; STATE_WIDTH];
|
||||
|
||||
// determine the number of field elements needed to encode `bytes` when each field element
|
||||
// represents at most 7 bytes.
|
||||
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
|
||||
|
||||
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
|
||||
// this to achieve:
|
||||
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
|
||||
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
        state[CAPACITY_RANGE.start] =
            Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `rate_pos` is not zero, then the chunk count wasn't enough to fill the rate range,
        // and an additional permutation must be performed.
        let mut current_chunk_idx = 0_usize;
        // handle the case of an empty `bytes`
        let last_chunk_idx = if num_field_elem == 0 {
            current_chunk_idx
        } else {
            num_field_elem - 1
        };
        let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
            // copy the chunk into the buffer
            if current_chunk_idx != last_chunk_idx {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                // on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
                // needed to fill it
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }
            current_chunk_idx += 1;

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if rate_pos == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                rate_pos + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating the number of field elements constituting the last block when the latter
        // is not divisible by `RATE_WIDTH`.
        if rate_pos != 0 {
            state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        Self::hash_elements(Self::Digest::digests_as_elements(values))
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element and
        //   set the first capacity element to 5.
        // - if the value doesn't fit into a single field element, split it into two field elements,
        //   copy them into rate elements 5 and 6 and set the first capacity element to 6.
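        //
        // as a worked instance (a sketch, not normative): with p = 2^64 - 2^32 + 1 and
        // value = p + 2, `Felt::new(value)` reduces to 2 and `value / p` is 1, so the two
        // rate elements become 2 and 1 and the first capacity element is set to 6.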
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[CAPACITY_RANGE.start] = Felt::from(5_u8);
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[CAPACITY_RANGE.start] = Felt::from(6_u8);
        }

        // apply the RPO permutation and return the first four elements of the rate
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

impl ElementHasher for Rpo256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to `elements.len() % RATE_WIDTH`.
        let mut state = [ZERO; STATE_WIDTH];
        state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
        // padding with as many 0s as necessary to make the input length a multiple of RATE_WIDTH.
        if i > 0 {
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

// HASH FUNCTION IMPLEMENTATION
// ================================================================================================

impl Rpo256 {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The number of rounds is set to 7 to target a 128-bit security level.
    pub const NUM_ROUNDS: usize = NUM_ROUNDS;

    /// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
    /// the remaining 4 elements are reserved for capacity.
    pub const STATE_WIDTH: usize = STATE_WIDTH;

    /// The rate portion of the state is located in elements 4 through 11 (inclusive).
    pub const RATE_RANGE: Range<usize> = RATE_RANGE;

    /// The capacity portion of the state is located in elements 0, 1, 2, and 3.
    pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;

    /// The output of the hash function can be read from state elements 4, 5, 6, and 7.
    pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in an RPO round.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the RPO round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the RPO round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpoDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
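    ///
    /// A hedged usage sketch (`d0` and `d1` are placeholder digests, hence `ignore`):
    /// ```ignore
    /// let node = Rpo256::merge_in_domain(&[d0, d1], Felt::new(1));
    /// ```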
    pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpoDigest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is used
        // for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPO permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    // RESCUE PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies the RPO permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        for i in 0..NUM_ROUNDS {
            Self::apply_round(state, i);
        }
    }

    /// RPO round function.
    #[inline(always)]
    pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
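        // a note inferred from this call pattern (not a documented contract):
        // `add_constants_and_apply_sbox` returns `false` when no accelerated implementation
        // is available, and we then fall back to the scalar constant-addition and s-box
        // application below.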
        // apply first half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        // apply second half of RPO round
        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }
}
387
src/hash/rescue/rpo/tests.rs
Normal file
@ -0,0 +1,387 @@
use alloc::{collections::BTreeSet, vec::Vec};

use proptest::prelude::*;
use rand_utils::rand_value;

use super::{
    super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
    Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, STATE_WIDTH, ZERO,
};
use crate::{
    hash::rescue::{BINARY_CHUNK_SIZE, CAPACITY_RANGE, RATE_WIDTH},
    Word, ONE,
};

#[test]
fn test_sbox() {
    let state = [Felt::new(rand_value()); STATE_WIDTH];

    let mut expected = state;
    expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));

    let mut actual = state;
    apply_sbox(&mut actual);

    assert_eq!(expected, actual);
}

#[test]
fn test_inv_sbox() {
    let state = [Felt::new(rand_value()); STATE_WIDTH];

    let mut expected = state;
    expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));

    let mut actual = state;
    apply_inv_sbox(&mut actual);

    assert_eq!(expected, actual);
}

#[test]
fn hash_elements_vs_merge() {
    let elements = [Felt::new(rand_value()); 8];

    let digests: [RpoDigest; 2] = [
        RpoDigest::new(elements[..4].try_into().unwrap()),
        RpoDigest::new(elements[4..].try_into().unwrap()),
    ];

    let m_result = Rpo256::merge(&digests);
    let h_result = Rpo256::hash_elements(&elements);
    assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_in_domain() {
    let elements = [Felt::new(rand_value()); 8];

    let digests: [RpoDigest; 2] = [
        RpoDigest::new(elements[..4].try_into().unwrap()),
        RpoDigest::new(elements[4..].try_into().unwrap()),
    ];
    let merge_result = Rpo256::merge(&digests);

    // ------------- merge with domain = 0 -------------

    // set domain to ZERO. This should not change the result.
    let domain = ZERO;

    let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
    assert_eq!(merge_result, merge_in_domain_result);

    // ------------- merge with domain = 1 -------------

    // set domain to ONE. This should change the result.
    let domain = ONE;

    let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
    assert_ne!(merge_result, merge_in_domain_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
    let tmp = [Felt::new(rand_value()); 4];
    let seed = RpoDigest::new(tmp);

    // ----- value fits into a field element ------------------------------------------------------
    let val: Felt = Felt::new(rand_value());
    let m_result = Rpo256::merge_with_int(seed, val.as_int());

    let mut elements = seed.as_elements().to_vec();
    elements.push(val);
    let h_result = Rpo256::hash_elements(&elements);

    assert_eq!(m_result, h_result);

    // ----- value does not fit into a field element ----------------------------------------------
    let val = Felt::MODULUS + 2;
    let m_result = Rpo256::merge_with_int(seed, val);

    let mut elements = seed.as_elements().to_vec();
    elements.push(Felt::new(val));
    elements.push(ONE);
    let h_result = Rpo256::hash_elements(&elements);

    assert_eq!(m_result, h_result);
}

#[test]
fn hash_padding() {
    // adding a zero byte at the end of a byte string should result in a different hash
    let r1 = Rpo256::hash(&[1_u8, 2, 3]);
    let r2 = Rpo256::hash(&[1_u8, 2, 3, 0]);
    assert_ne!(r1, r2);

    // same as above but with bigger inputs
    let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6]);
    let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
    assert_ne!(r1, r2);

    // same as above but with the input splitting over two elements
    let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
    let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
    assert_ne!(r1, r2);

    // same as above but with multiple zeros
    let r1 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
    let r2 = Rpo256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
    assert_ne!(r1, r2);
}

#[test]
fn hash_padding_no_extra_permutation_call() {
    use crate::hash::rescue::DIGEST_RANGE;

    // Implementation
    let num_bytes = BINARY_CHUNK_SIZE * RATE_WIDTH;
    let mut buffer = vec![0_u8; num_bytes];
    *buffer.last_mut().unwrap() = 97;
    let r1 = Rpo256::hash(&buffer);

    // Expected
    let final_chunk = [0_u8, 0, 0, 0, 0, 0, 97, 1];
    let mut state = [ZERO; STATE_WIDTH];
    // padding when hashing bytes
    state[CAPACITY_RANGE.start] = Felt::from(RATE_WIDTH as u8);
    *state.last_mut().unwrap() = Felt::new(u64::from_le_bytes(final_chunk));
    Rpo256::apply_permutation(&mut state);

    assert_eq!(&r1[0..4], &state[DIGEST_RANGE]);
}

#[test]
fn hash_elements_padding() {
    let e1 = [Felt::new(rand_value()); 2];
    let e2 = [e1[0], e1[1], ZERO];

    let r1 = Rpo256::hash_elements(&e1);
    let r2 = Rpo256::hash_elements(&e2);
    assert_ne!(r1, r2);
}

#[test]
fn hash_elements() {
    let elements = [
        ZERO,
        ONE,
        Felt::new(2),
        Felt::new(3),
        Felt::new(4),
        Felt::new(5),
        Felt::new(6),
        Felt::new(7),
    ];

    let digests: [RpoDigest; 2] = [
        RpoDigest::new(elements[..4].try_into().unwrap()),
        RpoDigest::new(elements[4..8].try_into().unwrap()),
    ];

    let m_result = Rpo256::merge(&digests);
    let h_result = Rpo256::hash_elements(&elements);
    assert_eq!(m_result, h_result);
}

#[test]
fn hash_empty() {
    let elements: Vec<Felt> = vec![];

    let zero_digest = RpoDigest::default();
    let h_result = Rpo256::hash_elements(&elements);
    assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_empty_bytes() {
    let bytes: Vec<u8> = vec![];

    let zero_digest = RpoDigest::default();
    let h_result = Rpo256::hash(&bytes);
    assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_test_vectors() {
    let elements = [
        ZERO,
        ONE,
        Felt::new(2),
        Felt::new(3),
        Felt::new(4),
        Felt::new(5),
        Felt::new(6),
        Felt::new(7),
        Felt::new(8),
        Felt::new(9),
        Felt::new(10),
        Felt::new(11),
        Felt::new(12),
        Felt::new(13),
        Felt::new(14),
        Felt::new(15),
        Felt::new(16),
        Felt::new(17),
        Felt::new(18),
    ];

    for i in 0..elements.len() {
        let expected = RpoDigest::new(EXPECTED[i]);
        let result = Rpo256::hash_elements(&elements[..(i + 1)]);
        assert_eq!(result, expected);
    }
}

#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
    // this test asserts that no panic occurs for the edge case of an input whose length is
    // not divisible by the binary chunk size in use. 113 is a non-negligible input length
    // that is prime, and hence guaranteed not to be divisible by any choice of chunk size
    // greater than one.
    //
    // this is a preliminary check ahead of the proptest-based fuzz test below.
    Rpo256::hash(&[0; 113]);
}

#[test]
fn sponge_collision_for_wrapped_field_element() {
    let a = Rpo256::hash(&[0; 8]);
    let b = Rpo256::hash(&Felt::MODULUS.to_le_bytes());
    assert_ne!(a, b);
}

#[test]
fn sponge_zeroes_collision() {
    let mut zeroes = Vec::with_capacity(255);
    let mut set = BTreeSet::new();
    (0..255).for_each(|_| {
        let hash = Rpo256::hash(&zeroes);
        zeroes.push(0);
        // panic if a collision was found
        assert!(set.insert(hash));
    });
}

proptest! {
    #[test]
    fn rpo256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
        Rpo256::hash(bytes);
    }
}

const EXPECTED: [Word; 19] = [
    [
        Felt::new(18126731724905382595),
        Felt::new(7388557040857728717),
        Felt::new(14290750514634285295),
        Felt::new(7852282086160480146),
    ],
    [
        Felt::new(10139303045932500183),
        Felt::new(2293916558361785533),
        Felt::new(15496361415980502047),
        Felt::new(17904948502382283940),
    ],
    [
        Felt::new(17457546260239634015),
        Felt::new(803990662839494686),
        Felt::new(10386005777401424878),
        Felt::new(18168807883298448638),
    ],
    [
        Felt::new(13072499238647455740),
        Felt::new(10174350003422057273),
        Felt::new(9201651627651151113),
        Felt::new(6872461887313298746),
    ],
    [
        Felt::new(2903803350580990546),
        Felt::new(1838870750730563299),
        Felt::new(4258619137315479708),
        Felt::new(17334260395129062936),
    ],
    [
        Felt::new(8571221005243425262),
        Felt::new(3016595589318175865),
        Felt::new(13933674291329928438),
        Felt::new(678640375034313072),
    ],
    [
        Felt::new(16314113978986502310),
        Felt::new(14587622368743051587),
        Felt::new(2808708361436818462),
        Felt::new(10660517522478329440),
    ],
    [
        Felt::new(2242391899857912644),
        Felt::new(12689382052053305418),
        Felt::new(235236990017815546),
        Felt::new(5046143039268215739),
    ],
    [
        Felt::new(5218076004221736204),
        Felt::new(17169400568680971304),
        Felt::new(8840075572473868990),
        Felt::new(12382372614369863623),
    ],
    [
        Felt::new(9783834557155203486),
        Felt::new(12317263104955018849),
        Felt::new(3933748931816109604),
        Felt::new(1843043029836917214),
    ],
    [
        Felt::new(14498234468286984551),
        Felt::new(16837257669834682387),
        Felt::new(6664141123711355107),
        Felt::new(4590460158294697186),
    ],
    [
        Felt::new(4661800562479916067),
        Felt::new(11794407552792839953),
        Felt::new(9037742258721863712),
        Felt::new(6287820818064278819),
    ],
    [
        Felt::new(7752693085194633729),
        Felt::new(7379857372245835536),
        Felt::new(9270229380648024178),
        Felt::new(10638301488452560378),
    ],
    [
        Felt::new(11542686762698783357),
        Felt::new(15570714990728449027),
        Felt::new(7518801014067819501),
        Felt::new(12706437751337583515),
    ],
    [
        Felt::new(9553923701032839042),
        Felt::new(7281190920209838818),
        Felt::new(2488477917448393955),
        Felt::new(5088955350303368837),
    ],
    [
        Felt::new(4935426252518736883),
        Felt::new(12584230452580950419),
        Felt::new(8762518969632303998),
        Felt::new(18159875708229758073),
    ],
    [
        Felt::new(12795429638314178838),
        Felt::new(14360248269767567855),
        Felt::new(3819563852436765058),
        Felt::new(10859123583999067291),
    ],
    [
        Felt::new(2695742617679420093),
        Felt::new(9151515850666059759),
        Felt::new(15855828029180595485),
        Felt::new(17190029785471463210),
    ],
    [
        Felt::new(13205273108219124830),
        Felt::new(2524898486192849221),
        Felt::new(14618764355375283547),
        Felt::new(10615614265042186874),
    ],
];
634
src/hash/rescue/rpx/digest.rs
Normal file
@ -0,0 +1,634 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};

use thiserror::Error;

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
    rand::Randomizable,
    utils::{
        bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
        DeserializationError, HexParseError, Serializable,
    },
};

// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
    /// The serialized size of the digest in bytes.
    pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;

    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }

    pub fn as_elements(&self) -> &[Felt] {
        self.as_ref()
    }

    pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        <Self as Digest>::as_bytes(self)
    }

    pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }

    pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
        let p = digests.as_ptr();
        let len = digests.len() * DIGEST_SIZE;
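        // SAFETY (a sketch of the argument, assuming `RpxDigest` is layout-compatible with
        // `[Felt; DIGEST_SIZE]`): a slice of digests is then one contiguous run of
        // `digests.len() * DIGEST_SIZE` field elements, so the reconstructed slice stays
        // within the original allocation.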
        unsafe { slice::from_raw_parts(p as *const Felt, len) }
    }

    /// Returns the hexadecimal representation of this digest prefixed with `0x`.
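    ///
    /// A minimal round-trip sketch (`digest` is a placeholder value, hence `ignore`):
    /// ```ignore
    /// let hex = digest.to_hex();
    /// assert_eq!(RpxDigest::try_from(hex.as_str()).unwrap(), digest);
    /// ```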
    pub fn to_hex(&self) -> String {
        bytes_to_hex_string(self.as_bytes())
    }
}

impl Digest for RpxDigest {
    fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
        let mut result = [0; DIGEST_BYTES];

        result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
        result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
        result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
        result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());

        result
    }
}

impl Deref for RpxDigest {
    type Target = [Felt; DIGEST_SIZE];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl Ord for RpxDigest {
    fn cmp(&self, other: &Self) -> Ordering {
        // compare the inner u64 of both elements.
        //
        // this iterates over the elements and returns the first comparison that is not
        // `Equal`; otherwise, the ordering is equal.
        //
        // the endianness is irrelevant here because, this being a cryptographically secure
        // hash, the digest shouldn't preserve any ordering property of its input.
        //
        // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
        // montgomery reduction for every limb. that is safe because every inner element of the
        // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
        self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
            Ordering::Equal,
            |ord, (a, b)| match ord {
                Ordering::Equal => a.cmp(&b),
                _ => ord,
            },
        )
    }
}

impl PartialOrd for RpxDigest {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Display for RpxDigest {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let encoded: String = self.into();
        write!(f, "{}", encoded)?;
        Ok(())
    }
}

impl Randomizable for RpxDigest {
    const VALUE_SIZE: usize = DIGEST_BYTES;

    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
        if let Some(bytes_array) = bytes_array {
            Self::try_from(bytes_array).ok()
        } else {
            None
        }
    }
}

// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

#[derive(Debug, Error)]
pub enum RpxDigestError {
    #[error("failed to convert digest field element to {0}")]
    TypeConversion(&'static str),
    #[error("failed to convert to field element: {0}")]
    InvalidFieldElement(String),
}

impl TryFrom<&RpxDigest> for [bool; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpxDigest> for [bool; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
        fn to_bool(v: u64) -> Option<bool> {
            if v <= 1 {
                Some(v == 1)
            } else {
                None
            }
        }

        Ok([
            to_bool(value.0[0].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
            to_bool(value.0[1].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
            to_bool(value.0[2].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
            to_bool(value.0[3].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
        ])
    }
}

impl TryFrom<&RpxDigest> for [u8; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpxDigest> for [u8; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u8"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u8"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u8"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u8"))?,
        ])
    }
}

impl TryFrom<&RpxDigest> for [u16; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpxDigest> for [u16; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u16"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u16"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u16"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u16"))?,
        ])
    }
}

impl TryFrom<&RpxDigest> for [u32; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<RpxDigest> for [u32; DIGEST_SIZE] {
    type Error = RpxDigestError;

    fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
        Ok([
            value.0[0]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u32"))?,
            value.0[1]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u32"))?,
            value.0[2]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u32"))?,
            value.0[3]
                .as_int()
                .try_into()
                .map_err(|_| RpxDigestError::TypeConversion("u32"))?,
        ])
    }
}

impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        (*value).into()
    }
}

impl From<RpxDigest> for [u64; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        [
            value.0[0].as_int(),
            value.0[1].as_int(),
            value.0[2].as_int(),
            value.0[3].as_int(),
        ]
    }
}

impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: &RpxDigest) -> Self {
        value.0
    }
}

impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
    fn from(value: RpxDigest) -> Self {
        value.0
    }
}

impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: RpxDigest) -> Self {
        value.as_bytes()
    }
}

impl From<&RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpxDigest) -> Self {
        (*value).into()
    }
}

impl From<RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: RpxDigest) -> Self {
        value.to_hex()
    }
}

// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

impl From<&[bool; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[bool; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[bool; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [bool; DIGEST_SIZE]) -> Self {
        [value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
    }
}

impl From<&[u8; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[u8; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u8; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [u8; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u16; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[u16; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u16; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [u16; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl From<&[u32; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[u32; DIGEST_SIZE]) -> Self {
        (*value).into()
    }
}

impl From<[u32; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [u32; DIGEST_SIZE]) -> Self {
        Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
    }
}

impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        (*value).try_into()
    }
}

impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
    type Error = RpxDigestError;

    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
        Ok(Self([
            value[0].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
            value[1].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
            value[2].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
            value[3].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
        ]))
    }
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
        Self(*value)
    }
}

impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
    fn from(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
}

impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
        // Note: the input length is known, the conversion from slice to array must succeed so the
        // `unwrap`s below are safe
        let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
        let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
        let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
        let d = u64::from_le_bytes(value[24..32].try_into().unwrap());

        if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
            return Err(HexParseError::OutOfRange);
        }

        Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
    }
}

impl TryFrom<&[u8]> for RpxDigest {
    type Error = HexParseError;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        (*value).try_into()
    }
}

impl TryFrom<&str> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpxDigest::try_from)
    }
}

impl TryFrom<&String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

impl TryFrom<String> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}

// SERIALIZATION / DESERIALIZATION
// ================================================================================================

impl Serializable for RpxDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }

    fn get_size_hint(&self) -> usize {
        Self::SERIALIZED_SIZE
    }
}

impl Deserializable for RpxDigest {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
        for inner in inner.iter_mut() {
            let e = source.read_u64()?;
            if e >= Felt::MODULUS {
                return Err(DeserializationError::InvalidValue(String::from(
                    "Value not in the appropriate range",
                )));
            }
            *inner = Felt::new(e);
        }

        Ok(Self(inner))
    }
}

// ITERATORS
// ================================================================================================
impl IntoIterator for RpxDigest {
    type Item = Felt;
    type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use alloc::string::String;

    use rand_utils::rand_value;

    use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
    use crate::utils::SliceReader;

    #[test]
    fn digest_serialization() {
        let e1 = Felt::new(rand_value());
        let e2 = Felt::new(rand_value());
        let e3 = Felt::new(rand_value());
        let e4 = Felt::new(rand_value());

        let d1 = RpxDigest([e1, e2, e3, e4]);

        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());
        assert_eq!(bytes.len(), d1.get_size_hint());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpxDigest::read_from(&mut reader).unwrap();

        assert_eq!(d1, d2);
    }

    #[test]
    fn digest_encoding() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        let string: String = digest.into();
        let round_trip: RpxDigest = string.try_into().expect("decoding failed");

        assert_eq!(digest, round_trip);
    }

    #[test]
    fn test_conversions() {
        let digest = RpxDigest([
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
            Felt::new(rand_value()),
        ]);

        // BY VALUE
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpxDigest = v.into();
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpxDigest = v.into();
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpxDigest = v.into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpxDigest = v.into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());

        let v: [u64; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = digest.into();
        let v2: RpxDigest = v.into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = digest.into();
        let v2: RpxDigest = v.try_into().unwrap();
        assert_eq!(digest, v2);

        // BY REF
        // ----------------------------------------------------------------------------------------
        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
        let v2: RpxDigest = (&v).into();
        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
        let v2: RpxDigest = (&v).into();
        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
        let v2: RpxDigest = (&v).into();
        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
        let v2: RpxDigest = (&v).into();
        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());

        let v: [u64; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: [Felt; DIGEST_SIZE] = (&digest).into();
        let v2: RpxDigest = (&v).into();
        assert_eq!(digest, v2);

        let v: [u8; DIGEST_BYTES] = (&digest).into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);

        let v: String = (&digest).into();
        let v2: RpxDigest = (&v).try_into().unwrap();
        assert_eq!(digest, v2);
    }
}
385
src/hash/rescue/rpx/mod.rs
Normal file
@ -0,0 +1,385 @@
use core::ops::Range;

use super::{
    add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
    apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
    StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
    DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
    ZERO,
};

mod digest;
pub use digest::{RpxDigest, RpxDigestError};

#[cfg(test)]
mod tests;

pub type CubicExtElement = CubeExtension<Felt>;

// HASHER IMPLEMENTATION
// ================================================================================================

/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
///
/// The hash function is based on the XHash12 construction described in the
/// [specification](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
///   - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
///     `apply_inv_sbox`.
///   - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension
///     field).
///   - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
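///
/// As a sketch of how this maps onto the code below: with `NUM_ROUNDS = 7`, round constants
/// are consumed in order 0..=6 by (FB), (E), (FB), (E), (FB), (E), (M), matching
/// [apply_permutation()](Rpx256::apply_permutation).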
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
/// same result as hashing the 8 elements which make up these digests using the
/// [hash_elements()](Rpx256::hash_elements) function.
///
/// However, the [hash()](Rpx256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these
/// elements directly using the [hash_elements()](Rpx256::hash_elements) function. The reason for
/// this difference is that the [hash()](Rpx256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus, the
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpx256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpx256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier, and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation
/// in Appendix C of the [specification](https://eprint.iacr.org/2023/1045), one sees that doing
/// so degrades only pre-image resistance, from its initial bound of c * log_2(p), by as much as
/// the log_2 of the size of the domain identifier space. Since pre-image resistance becomes
/// the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPX permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();

impl Hasher for Rpx256 {
    /// Rpx256 collision resistance is 128-bits.
    const COLLISION_RESISTANCE: u32 = 128;

    type Digest = RpxDigest;

    fn hash(bytes: &[u8]) -> Self::Digest {
        // initialize the state with zeroes
        let mut state = [ZERO; STATE_WIDTH];

        // determine the number of field elements needed to encode `bytes` when each field element
        // represents at most 7 bytes.
        let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);

        // set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
        // this to achieve:
        // 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
        // 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
        state[CAPACITY_RANGE.start] =
            Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);

        // initialize a buffer to receive the little-endian elements.
        let mut buf = [0_u8; 8];

        // iterate the chunks of bytes, creating a field element from each chunk and copying it
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
        // `rate_pos` is not zero, then the chunk count wasn't enough to fill the rate range,
        // and an additional permutation must be performed.
        let mut current_chunk_idx = 0_usize;
        // handle the case of an empty `bytes`
        let last_chunk_idx = if num_field_elem == 0 {
            current_chunk_idx
        } else {
            num_field_elem - 1
        };
        let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
            // copy the chunk into the buffer
            if current_chunk_idx != last_chunk_idx {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                // on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
                // needed to fill it
                buf.fill(0);
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }
            current_chunk_idx += 1;

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the input data will fit into a single field element.
            state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
            if rate_pos == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
                rate_pos + 1
            }
        });

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating the number of field elements constituting the last block when the latter
        // is not divisible by `RATE_WIDTH`.
        if rate_pos != 0 {
            state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the rate as hash result.
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
        let it = Self::Digest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
        Self::hash_elements(Self::Digest::digests_as_elements(values))
    }

    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element and
        //   set the first capacity element to 5.
        // - if the value doesn't fit into a single field element, split it into two field elements,
        //   copy them into rate elements 5 and 6 and set the first capacity element to 6.
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
        if value < Felt::MODULUS {
            state[CAPACITY_RANGE.start] = Felt::from(5_u8);
        } else {
            state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
            state[CAPACITY_RANGE.start] = Felt::from(6_u8);
        }

        // apply the RPX permutation and return the first four elements of the rate
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}

impl ElementHasher for Rpx256 {
    type BaseField = Felt;

    fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
        // convert the elements into a list of base field elements
        let elements = E::slice_as_base_elements(elements);

        // initialize state to all zeros, except for the first element of the capacity part, which
        // is set to `elements.len() % RATE_WIDTH`.
        let mut state = [ZERO; STATE_WIDTH];
        state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);

        // absorb elements into the state one by one until the rate portion of the state is filled
        // up; then apply the Rescue permutation and start absorbing again; repeat until all
        // elements have been absorbed
        let mut i = 0;
        for &element in elements.iter() {
            state[RATE_RANGE.start + i] = element;
            i += 1;
            if i % RATE_WIDTH == 0 {
                Self::apply_permutation(&mut state);
                i = 0;
            }
        }

        // if we absorbed some elements but didn't apply a permutation to them (would happen when
        // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
        // padding with as many 0s as necessary to make the input length a multiple of RATE_WIDTH.
        if i > 0 {
            while i != RATE_WIDTH {
                state[RATE_RANGE.start + i] = ZERO;
                i += 1;
            }
            Self::apply_permutation(&mut state);
        }

        // return the first 4 elements of the state as hash result
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
}
|
||||
|
||||
// HASH FUNCTION IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Rpx256 {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and
|
||||
/// the remaining 4 elements are reserved for capacity.
|
||||
pub const STATE_WIDTH: usize = STATE_WIDTH;
|
||||
|
||||
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
|
||||
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
|
||||
|
||||
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
|
||||
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
|
||||
|
||||
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
|
||||
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;

    /// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
    pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;

    /// Round constants added to the hasher state in the first half of the round.
    pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;

    /// Round constants added to the hasher state in the second half of the round.
    pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;

    // TRAIT PASS-THROUGH FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of the provided sequence of bytes.
    #[inline(always)]
    pub fn hash(bytes: &[u8]) -> RpxDigest {
        <Self as Hasher>::hash(bytes)
    }

    /// Returns a hash of two digests. This method is intended for use in construction of
    /// Merkle trees and verification of Merkle paths.
    #[inline(always)]
    pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
        <Self as Hasher>::merge(values)
    }

    /// Returns a hash of the provided field elements.
    #[inline(always)]
    pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
        <Self as ElementHasher>::hash_elements(elements)
    }

    // DOMAIN IDENTIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns a hash of two digests and a domain identifier.
    pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0
        let mut state = [ZERO; STATE_WIDTH];
        let it = RpxDigest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }

        // set the second capacity element to the domain value. The first capacity element is
        // used for padding purposes.
        state[CAPACITY_RANGE.start + 1] = domain;

        // apply the RPX permutation and return the first four elements of the state
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
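
    // Note (added for illustration): with `domain == ZERO` this reproduces plain `merge`,
    // since the capacity portion is all zeros in both cases; any non-zero domain yields an
    // unrelated digest (see `merge_vs_merge_in_domain` in `tests.rs` below).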

    // RPX PERMUTATION
    // --------------------------------------------------------------------------------------------

    /// Applies RPX permutation to the provided state.
    #[inline(always)]
    pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
        Self::apply_fb_round(state, 0);
        Self::apply_ext_round(state, 1);
        Self::apply_fb_round(state, 2);
        Self::apply_ext_round(state, 3);
        Self::apply_fb_round(state, 4);
        Self::apply_ext_round(state, 5);
        Self::apply_final_round(state, 6);
    }

    // RPX PERMUTATION ROUND FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// (FB) round function.
    #[inline(always)]
    pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        // try the combined constant-addition-and-S-box path; fall back to the scalar
        // implementation if no accelerated version is available (signaled by `false`)
        if !add_constants_and_apply_sbox(state, &ARK1[round]) {
            add_constants(state, &ARK1[round]);
            apply_sbox(state);
        }

        apply_mds(state);
        if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
            add_constants(state, &ARK2[round]);
            apply_inv_sbox(state);
        }
    }

    /// (E) round function.
    #[inline(always)]
    pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        // add constants
        add_constants(state, &ARK1[round]);

        // decompose the state into 4 elements in the cubic extension field and apply the power 7
        // map to each of the elements
        let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
        let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
        let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
        let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
        let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));

        // recompose the state back into 12 base field elements
        let arr_ext = [ext0, ext1, ext2, ext3];
        *state = CubicExtElement::slice_as_base_elements(&arr_ext)
            .try_into()
            .expect("shouldn't fail");
    }

    /// (M) round function.
    #[inline(always)]
    pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
        apply_mds(state);
        add_constants(state, &ARK1[round]);
    }

    /// Computes an exponentiation to the power 7 in the cubic extension field.
    ///
    /// Uses 2 squarings and 2 multiplications via the addition chain x^7 = x^3 * x^4.
    #[inline(always)]
    pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
        let x2 = x.square();
        let x4 = x2.square();

        let x3 = x2 * x;
        x3 * x4
    }
}
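
// A sketch of an equivalence check (added for illustration; `rand_value` is the same helper
// used in `tests.rs` below): `exp7` should agree with naive repeated multiplication.
#[cfg(test)]
#[test]
fn exp7_matches_naive_sketch() {
    use rand_utils::rand_value;

    let x = CubeExtension::new(
        Felt::new(rand_value()),
        Felt::new(rand_value()),
        Felt::new(rand_value()),
    );
    let mut naive = x;
    for _ in 0..6 {
        naive = naive * x;
    }
    assert_eq!(naive, Rpx256::exp7(x));
}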
186
src/hash/rescue/rpx/tests.rs
Normal file
@ -0,0 +1,186 @@
use alloc::{collections::BTreeSet, vec::Vec};

use proptest::prelude::*;
use rand_utils::rand_value;

use super::{Felt, Hasher, Rpx256, StarkField, ZERO};
use crate::{hash::rescue::RpxDigest, ONE};

#[test]
fn hash_elements_vs_merge() {
    let elements = [Felt::new(rand_value()); 8];

    let digests: [RpxDigest; 2] = [
        RpxDigest::new(elements[..4].try_into().unwrap()),
        RpxDigest::new(elements[4..].try_into().unwrap()),
    ];

    let m_result = Rpx256::merge(&digests);
    let h_result = Rpx256::hash_elements(&elements);
    assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_in_domain() {
    let elements = [Felt::new(rand_value()); 8];

    let digests: [RpxDigest; 2] = [
        RpxDigest::new(elements[..4].try_into().unwrap()),
        RpxDigest::new(elements[4..].try_into().unwrap()),
    ];
    let merge_result = Rpx256::merge(&digests);

    // ----- merge with domain = 0 ----------------------------------------------------------------

    // set domain to ZERO. This should not change the result.
    let domain = ZERO;

    let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
    assert_eq!(merge_result, merge_in_domain_result);

    // ----- merge with domain = 1 ----------------------------------------------------------------

    // set domain to ONE. This should change the result.
    let domain = ONE;

    let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
    assert_ne!(merge_result, merge_in_domain_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
    let tmp = [Felt::new(rand_value()); 4];
    let seed = RpxDigest::new(tmp);

    // ----- value fits into a field element ------------------------------------------------------
    let val: Felt = Felt::new(rand_value());
    let m_result = Rpx256::merge_with_int(seed, val.as_int());

    let mut elements = seed.as_elements().to_vec();
    elements.push(val);
    let h_result = Rpx256::hash_elements(&elements);

    assert_eq!(m_result, h_result);

    // ----- value does not fit into a field element ----------------------------------------------
    let val = Felt::MODULUS + 2;
    let m_result = Rpx256::merge_with_int(seed, val);

    let mut elements = seed.as_elements().to_vec();
    elements.push(Felt::new(val));
    elements.push(ONE);
    let h_result = Rpx256::hash_elements(&elements);

    assert_eq!(m_result, h_result);
}

#[test]
fn hash_padding() {
    // appending a zero byte at the end of a byte string should result in a different hash
    let r1 = Rpx256::hash(&[1_u8, 2, 3]);
    let r2 = Rpx256::hash(&[1_u8, 2, 3, 0]);
    assert_ne!(r1, r2);

    // same as above but with bigger inputs
    let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6]);
    let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
    assert_ne!(r1, r2);

    // same as above but with the input splitting over two elements
    let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
    let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
    assert_ne!(r1, r2);

    // same as above but with multiple zeros
    let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
    let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
    assert_ne!(r1, r2);
}

#[test]
fn hash_elements_padding() {
    let e1 = [Felt::new(rand_value()); 2];
    let e2 = [e1[0], e1[1], ZERO];

    let r1 = Rpx256::hash_elements(&e1);
    let r2 = Rpx256::hash_elements(&e2);
    assert_ne!(r1, r2);
}

#[test]
fn hash_elements() {
    let elements = [
        ZERO,
        ONE,
        Felt::new(2),
        Felt::new(3),
        Felt::new(4),
        Felt::new(5),
        Felt::new(6),
        Felt::new(7),
    ];

    let digests: [RpxDigest; 2] = [
        RpxDigest::new(elements[..4].try_into().unwrap()),
        RpxDigest::new(elements[4..8].try_into().unwrap()),
    ];

    let m_result = Rpx256::merge(&digests);
    let h_result = Rpx256::hash_elements(&elements);
    assert_eq!(m_result, h_result);
}

#[test]
fn hash_empty() {
    let elements: Vec<Felt> = vec![];

    let zero_digest = RpxDigest::default();
    let h_result = Rpx256::hash_elements(&elements);
    assert_eq!(zero_digest, h_result);
}

#[test]
fn hash_empty_bytes() {
    let bytes: Vec<u8> = vec![];

    let zero_digest = RpxDigest::default();
    let h_result = Rpx256::hash(&bytes);
    assert_eq!(zero_digest, h_result);
}

#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
    // this test asserts that no panic occurs in the edge case where the input length is not
    // divisible by the binary chunk size used during absorption. 113 is a non-negligible
    // input length that is prime, and hence guaranteed not to be divisible by any choice of
    // chunk size.
    //
    // this is a preliminary test to the proptest fuzz-stress below.
    Rpx256::hash(&[0; 113]);
}

#[test]
fn sponge_collision_for_wrapped_field_element() {
    let a = Rpx256::hash(&[0; 8]);
    let b = Rpx256::hash(&Felt::MODULUS.to_le_bytes());
    assert_ne!(a, b);
}

#[test]
fn sponge_zeroes_collision() {
    let mut zeroes = Vec::with_capacity(255);
    let mut set = BTreeSet::new();
    (0..255).for_each(|_| {
        let hash = Rpx256::hash(&zeroes);
        zeroes.push(0);
        // panic if a collision was found
        assert!(set.insert(hash));
    });
}

proptest! {
    #[test]
    fn rpx256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
        Rpx256::hash(bytes);
    }
}
10
src/hash/rescue/tests.rs
Normal file
@ -0,0 +1,10 @@
use rand_utils::rand_value;

use super::{Felt, FieldElement, ALPHA, INV_ALPHA};

#[test]
fn test_alphas() {
    let e: Felt = Felt::new(rand_value());
    let e_exp = e.exp(ALPHA);
    assert_eq!(e, e_exp.exp(INV_ALPHA));
}
71
src/lib.rs
Normal file
@ -0,0 +1,71 @@
#![no_std]

#[macro_use]
extern crate alloc;

#[cfg(feature = "std")]
extern crate std;

pub mod dsa;
pub mod hash;
pub mod merkle;
pub mod rand;
pub mod utils;

// RE-EXPORTS
// ================================================================================================

pub use winter_math::{
    fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
    FieldElement, StarkField,
};

// TYPE ALIASES
// ================================================================================================

/// A group of four field elements in the Miden base field.
pub type Word = [Felt; WORD_SIZE];

// CONSTANTS
// ================================================================================================

/// Number of field elements in a word.
pub const WORD_SIZE: usize = 4;

/// Field element representing ZERO in the Miden base field.
pub const ZERO: Felt = Felt::ZERO;

/// Field element representing ONE in the Miden base field.
pub const ONE: Felt = Felt::ONE;

/// Array of field elements representing a word of ZEROs in the Miden base field.
pub const EMPTY_WORD: [Felt; 4] = [ZERO; WORD_SIZE];
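
// A small usage sketch of the aliases and constants above (added for illustration, not
// part of the original diff).
#[test]
fn word_constants_sketch() {
    let w: Word = EMPTY_WORD;
    assert_eq!(w, [ZERO; WORD_SIZE]);
    assert_ne!([ONE, ZERO, ZERO, ZERO], w);
}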

// TESTS
// ================================================================================================

#[test]
#[should_panic]
fn debug_assert_is_checked() {
    // enforce that release-mode test runs always set `RUSTFLAGS="-C debug-assertions"`.
    //
    // some upstream tests are performed with `debug_assert`, and we want to assert its
    // correctness downstream.
    //
    // for reference, check
    // https://github.com/0xPolygonMiden/miden-vm/issues/433
    debug_assert!(false);
}

#[test]
#[should_panic]
#[allow(arithmetic_overflow)]
fn overflow_panics_for_test() {
    // overflow checks might be disabled if tests are performed in release mode. these are
    // critical, mandatory checks as overflows might be attack vectors.
    //
    // to enable overflow checks in release mode, ensure `RUSTFLAGS="-C overflow-checks"`
    let a = 1_u64;
    let b = 64;
    assert_ne!(a << b, 0);
}
219
src/main.rs
Normal file
@ -0,0 +1,219 @@
use std::time::Instant;

use clap::Parser;
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    merkle::{MerkleError, Smt},
    Felt, Word, EMPTY_WORD, ONE,
};
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use rand_utils::rand_value;

#[derive(Parser, Debug)]
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
pub struct BenchmarkCmd {
    /// Size of the tree
    #[clap(short = 's', long = "size", default_value = "1000000")]
    size: usize,
    /// Number of insertions
    #[clap(short = 'i', long = "insertions", default_value = "1000")]
    insertions: usize,
    /// Number of updates
    #[clap(short = 'u', long = "updates", default_value = "1000")]
    updates: usize,
}
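
// Example invocation (a sketch; flags per the clap attributes above, binary name assumed):
//
//   cargo run --release -- --size 1000000 --insertions 1000 --updates 1000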

fn main() {
    benchmark_smt();
}

/// Runs a benchmark for [`Smt`].
pub fn benchmark_smt() {
    let args = BenchmarkCmd::parse();
    let tree_size = args.size;
    let insertions = args.insertions;
    let updates = args.updates;

    assert!(updates <= tree_size, "Cannot update more than `size`");

    // prepare the `entries` vector for tree creation
    let mut entries = Vec::new();
    for i in 0..tree_size {
        let key = rand_value::<RpoDigest>();
        let value = [ONE, ONE, ONE, Felt::new(i as u64)];
        entries.push((key, value));
    }

    let mut tree = construction(entries.clone(), tree_size).unwrap();
    insertion(&mut tree.clone(), insertions).unwrap();
    batched_insertion(&mut tree.clone(), insertions).unwrap();
    batched_update(&mut tree.clone(), entries, updates).unwrap();
    proof_generation(&mut tree).unwrap();
}

/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
    println!("Running a construction benchmark:");
    let now = Instant::now();
    let tree = Smt::with_entries(entries)?;
    let elapsed = now.elapsed().as_secs_f32();
    println!("Constructed an SMT with {size} key-value pairs in {elapsed:.1} seconds");
    println!("Number of leaf nodes: {}\n", tree.leaves().count());

    Ok(tree)
}

/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
    println!("Running an insertion benchmark:");

    let size = tree.num_leaves();
    let mut insertion_times = Vec::new();

    for i in 0..insertions {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];

        let now = Instant::now();
        tree.insert(test_key, test_value);
        let elapsed = now.elapsed();
        insertion_times.push(elapsed.as_micros());
    }

    println!(
        "The average insertion time measured over {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n",
        // calculate the average
        insertion_times.iter().sum::<u128>() as f64 / (insertions as f64),
    );

    Ok(())
}

/// Runs the batched insertion benchmark for the [`Smt`].
pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
    println!("Running a batched insertion benchmark:");

    let size = tree.num_leaves();

    let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions)
        .map(|i| {
            let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
            let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
            (key, value)
        })
        .collect();

    let now = Instant::now();
    let mutations = tree.compute_mutations(new_pairs);
    let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

    println!(
        "Computing a mutation set for a {insertions}-insertion batch into an SMT with {size} leaves took {:.1} ms ({:.0} μs per insertion)",
        compute_elapsed,
        compute_elapsed * 1000_f64 / insertions as f64, // time in μs
    );

    let now = Instant::now();
    tree.apply_mutations(mutations)?;
    let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

    println!(
        "Applying the mutation set of the {insertions}-insertion batch to an SMT with {size} leaves took {:.1} ms ({:.0} μs per insertion)",
        apply_elapsed,
        apply_elapsed * 1000_f64 / insertions as f64, // time in μs
    );

    println!(
        "The batch insertion of {insertions} entries into an SMT with {size} leaves totals to {:.1} ms",
        (compute_elapsed + apply_elapsed),
    );

    println!();

    Ok(())
}

/// Runs the batched update benchmark for the [`Smt`].
pub fn batched_update(
    tree: &mut Smt,
    entries: Vec<(RpoDigest, Word)>,
    updates: usize,
) -> Result<(), MerkleError> {
    const REMOVAL_PROBABILITY: f64 = 0.2;

    println!("Running a batched update benchmark:");

    let size = tree.num_leaves();
    let mut rng = thread_rng();

    let new_pairs =
        entries
            .into_iter()
            .choose_multiple(&mut rng, updates)
            .into_iter()
            .map(|(key, _)| {
                let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
                    EMPTY_WORD
                } else {
                    [ONE, ONE, ONE, Felt::new(rng.gen())]
                };

                (key, value)
            });

    assert_eq!(new_pairs.len(), updates);

    let now = Instant::now();
    let mutations = tree.compute_mutations(new_pairs);
    let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

    let now = Instant::now();
    tree.apply_mutations(mutations)?;
    let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

    println!(
        "Computing a mutation set for a {updates}-update batch against an SMT with {size} leaves took {:.1} ms ({:.0} μs per update)",
        compute_elapsed,
        compute_elapsed * 1000_f64 / updates as f64, // time in μs
    );

    println!(
        "Applying the mutation set of the {updates}-update batch to an SMT with {size} leaves took {:.1} ms ({:.0} μs per update)",
        apply_elapsed,
        apply_elapsed * 1000_f64 / updates as f64, // time in μs
    );

    println!(
        "The batch update of {updates} entries against an SMT with {size} leaves totals to {:.1} ms",
        (compute_elapsed + apply_elapsed),
    );

    println!();

    Ok(())
}

/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
    const NUM_PROOFS: usize = 100;

    println!("Running a proof generation benchmark:");

    let mut proving_times = Vec::new();
    let size = tree.num_leaves();

    for i in 0..NUM_PROOFS {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
        tree.insert(test_key, test_value);

        let now = Instant::now();
        let _proof = tree.open(&test_key);
        proving_times.push(now.elapsed().as_micros());
    }

    println!(
        "The average proving time measured over {NUM_PROOFS} value proofs in an SMT with {size} leaves is {:.0} μs",
        // calculate the average
        proving_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
    );

    Ok(())
}
1617
src/merkle/empty_roots.rs
Normal file
File diff suppressed because it is too large
36
src/merkle/error.rs
Normal file
@ -0,0 +1,36 @@
use thiserror::Error;

use super::{NodeIndex, RpoDigest};

#[derive(Debug, Error)]
pub enum MerkleError {
    #[error("expected merkle root {expected_root} found {actual_root}")]
    ConflictingRoots {
        expected_root: RpoDigest,
        actual_root: RpoDigest,
    },
    #[error("provided merkle tree depth {0} is too small")]
    DepthTooSmall(u8),
    #[error("provided merkle tree depth {0} is too big")]
    DepthTooBig(u64),
    #[error("multiple values provided for merkle tree index {0}")]
    DuplicateValuesForIndex(u64),
    #[error("node index value {value} is not valid for depth {depth}")]
    InvalidNodeIndex { depth: u8, value: u64 },
    #[error("provided node index depth {provided} does not match expected depth {expected}")]
    InvalidNodeIndexDepth { expected: u8, provided: u8 },
    #[error("merkle subtree depth {subtree_depth} exceeds merkle tree depth {tree_depth}")]
    SubtreeDepthExceedsDepth { subtree_depth: u8, tree_depth: u8 },
    #[error("number of entries in the merkle tree exceeds the maximum of {0}")]
    TooManyEntries(usize),
    #[error("node index `{0}` not found in the tree")]
    NodeIndexNotFoundInTree(NodeIndex),
    #[error("node {0:?} with index `{1}` not found in the store")]
    NodeIndexNotFoundInStore(RpoDigest, NodeIndex),
    #[error("number of provided merkle tree leaves {0} is not a power of two")]
    NumLeavesNotPowerOfTwo(usize),
    #[error("root {0:?} is not in the store")]
    RootNotInStore(RpoDigest),
    #[error("partial smt does not track the merkle path for key {0} so updating it would produce a different root compared to the same update in the full tree")]
    UntrackedKey(RpoDigest),
}
244
src/merkle/index.rs
Normal file
@ -0,0 +1,244 @@
use core::fmt::Display;

use super::{Felt, MerkleError, RpoDigest};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// NODE INDEX
// ================================================================================================

/// Address to an arbitrary node in a binary tree using level order form.
///
/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
/// are numbered from $0..(2^d)-1$. Example:
///
/// ```ignore
/// depth
/// 0             0
/// 1         0        1
/// 2       0   1    2   3
/// 3      0 1 2 3  4 5 6 7
/// ```
///
/// The root is represented by the pair $(0, 0)$, its left child is $(1, 0)$ and its right child
/// $(1, 1)$.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct NodeIndex {
    depth: u8,
    value: u64,
}

impl NodeIndex {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates a new node index.
    ///
    /// # Errors
    /// Returns an error if the `value` is greater than or equal to 2^{depth}.
    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
        if (64 - value.leading_zeros()) > depth as u32 {
            Err(MerkleError::InvalidNodeIndex { depth, value })
        } else {
            Ok(Self { depth, value })
        }
    }

    /// Creates a new node index without checking its validity.
    pub const fn new_unchecked(depth: u8, value: u64) -> Self {
        debug_assert!((64 - value.leading_zeros()) <= depth as u32);
        Self { depth, value }
    }

    /// Creates a new node index for testing purposes.
    ///
    /// # Panics
    /// Panics if the `value` is greater than or equal to 2^{depth}.
    #[cfg(test)]
    pub fn make(depth: u8, value: u64) -> Self {
        Self::new(depth, value).unwrap()
    }

    /// Creates a node index from a pair of field elements representing the depth and value.
    ///
    /// # Errors
    /// Returns an error if:
    /// - `depth` doesn't fit in a `u8`.
    /// - `value` is greater than or equal to 2^{depth}.
    pub fn from_elements(depth: &Felt, value: &Felt) -> Result<Self, MerkleError> {
        let depth = depth.as_int();
        let depth = u8::try_from(depth).map_err(|_| MerkleError::DepthTooBig(depth))?;
        let value = value.as_int();
        Self::new(depth, value)
    }

    /// Creates a new node index pointing to the root of the tree.
    pub const fn root() -> Self {
        Self { depth: 0, value: 0 }
    }

    /// Computes the sibling index of the current node.
    pub const fn sibling(mut self) -> Self {
        self.value ^= 1;
        self
    }

    /// Returns the left child index of the current node.
    pub const fn left_child(mut self) -> Self {
        self.depth += 1;
        self.value <<= 1;
        self
    }

    /// Returns the right child index of the current node.
    pub const fn right_child(mut self) -> Self {
        self.depth += 1;
        self.value = (self.value << 1) + 1;
        self
    }

    /// Returns the parent of the current node. This is the same as [`Self::move_up()`], but
    /// returns a new value instead of mutating `self`.
    pub const fn parent(mut self) -> Self {
        self.depth = self.depth.saturating_sub(1);
        self.value >>= 1;
        self
    }

    // PROVIDERS
    // --------------------------------------------------------------------------------------------

    /// Builds a node to be used as input of a hash function when computing a Merkle path.
    ///
    /// Will evaluate the parity of the current instance to define the result.
    pub const fn build_node(&self, slf: RpoDigest, sibling: RpoDigest) -> [RpoDigest; 2] {
        if self.is_value_odd() {
            [sibling, slf]
        } else {
            [slf, sibling]
        }
    }

    /// Returns the scalar representation of the depth/value pair.
    ///
    /// It is computed as `2^depth + value`.
    pub const fn to_scalar_index(&self) -> u64 {
        (1 << self.depth as u64) + self.value
    }

    /// Returns the depth of the current instance.
    pub const fn depth(&self) -> u8 {
        self.depth
    }

    /// Returns the value of this index.
    pub const fn value(&self) -> u64 {
        self.value
    }

    /// Returns `true` if the current instance points to a right sibling node.
    pub const fn is_value_odd(&self) -> bool {
        (self.value & 1) == 1
    }

    /// Returns `true` if the depth is `0`.
    pub const fn is_root(&self) -> bool {
        self.depth == 0
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Traverses one level towards the root, decrementing the depth by `1`.
    pub fn move_up(&mut self) {
        self.depth = self.depth.saturating_sub(1);
        self.value >>= 1;
    }

    /// Traverses towards the root until the specified depth is reached.
    ///
    /// Assumes that the specified depth is smaller than the current depth.
    pub fn move_up_to(&mut self, depth: u8) {
        debug_assert!(depth < self.depth);
        let delta = self.depth.saturating_sub(depth);
        self.depth = self.depth.saturating_sub(delta);
        self.value >>= delta as u32;
    }
}
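
// Worked example (added for illustration): with the level-order numbering above, the root
// maps to scalar index 1, its children map to 2 and 3, and the node at (depth 2, value 3)
// maps to 2^2 + 3 = 7.
#[cfg(test)]
#[test]
fn to_scalar_index_examples() {
    assert_eq!(NodeIndex::root().to_scalar_index(), 1);
    assert_eq!(NodeIndex::root().left_child().to_scalar_index(), 2);
    assert_eq!(NodeIndex::root().right_child().to_scalar_index(), 3);
    assert_eq!(NodeIndex::make(2, 3).to_scalar_index(), 7);
}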

impl Display for NodeIndex {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "depth={}, value={}", self.depth, self.value)
    }
}

impl Serializable for NodeIndex {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_u8(self.depth);
        target.write_u64(self.value);
    }
}

impl Deserializable for NodeIndex {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let depth = source.read_u8()?;
        let value = source.read_u64()?;
        NodeIndex::new(depth, value)
            .map_err(|_| DeserializationError::InvalidValue("Invalid index".into()))
    }
}

#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;

    #[test]
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]
    fn test_node_index_can_represent_depth_64() {
        assert!(NodeIndex::new(64, u64::MAX).is_ok());
    }

    prop_compose! {
        fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
            // the smallest valid depth for `value` is its bit length, which is at most
            // u64::BITS, so the unwrap below never panics
            let depth = (u64::BITS - value.leading_zeros()) as u8;
            NodeIndex::new(depth, value).unwrap()
        }
    }

    proptest! {
        #[test]
        fn arbitrary_index_wont_panic_on_move_up(
            mut index in node_index(),
            count in prop::num::u8::ANY,
        ) {
            for _ in 0..count {
                index.move_up();
            }
        }
    }
}
450
src/merkle/merkle_tree.rs
Normal file
@ -0,0 +1,450 @@
use alloc::{string::String, vec::Vec};
use core::{fmt, ops::Deref, slice};

use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
use crate::utils::{uninit_vector, word_to_hex};

// MERKLE TREE
// ================================================================================================

/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTree {
    nodes: Vec<RpoDigest>,
}

impl MerkleTree {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------
    /// Returns a Merkle tree instantiated from the provided leaves.
    ///
    /// # Errors
    /// Returns an error if the number of leaves is smaller than two or is not a power of two.
    pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
    where
        T: AsRef<[Word]>,
    {
        let leaves = leaves.as_ref();
        let n = leaves.len();
        if n <= 1 {
            return Err(MerkleError::DepthTooSmall(n as u8));
        } else if !n.is_power_of_two() {
            return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
        }

        // create an un-initialized vector to hold all tree nodes
        let mut nodes = unsafe { uninit_vector(2 * n) };
        nodes[0] = RpoDigest::default();

        // copy leaves into the second half of the nodes vector
        nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
            *node = RpoDigest::from(*leaf);
        });

        // re-interpret nodes as an array of two nodes fused together
        // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
        // `self`).
        let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
        let pairs = unsafe { slice::from_raw_parts(ptr, n) };

        // calculate all internal tree nodes
        for i in (1..n).rev() {
            nodes[i] = Rpo256::merge(&pairs[i]);
        }

        Ok(Self { nodes })
    }
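
    // Layout sketch (added for illustration) for `nodes` with n = 4 leaves a, b, c, d,
    // following the construction above; index 0 is padding, index 1 is the root, and the
    // leaves occupy indices n..2n:
    //
    //   nodes = [ _, root, h(a, b), h(c, d), a, b, c, d ]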

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of this Merkle tree.
    pub fn root(&self) -> RpoDigest {
        self.nodes[1]
    }

    /// Returns the depth of this Merkle tree.
    ///
    /// A Merkle tree of depth 1 has two leaves, depth 2 has four leaves, etc.
    pub fn depth(&self) -> u8 {
        (self.nodes.len() / 2).ilog2() as u8
    }

    /// Returns a node at the specified depth and index value.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The specified depth is greater than the depth of the tree.
    /// * The specified index is not valid for the specified depth.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > self.depth() {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }

        let pos = index.to_scalar_index() as usize;
        Ok(self.nodes[pos])
    }

    /// Returns a Merkle path to the node at the specified depth and index value. The node itself
    /// is not included in the path.
    ///
    /// # Errors
    /// Returns an error if:
    /// * The specified depth is greater than the depth of the tree.
    /// * The specified value is not valid for the specified depth.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
            return Err(MerkleError::DepthTooSmall(index.depth()));
        } else if index.depth() > self.depth() {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }

        // TODO should we create a helper in `NodeIndex` that will encapsulate traversal to the
        // root so we always use an inlined `for` instead of `while`? the reason to use `for` is
        // that it is easier for the compiler to vectorize.
        let mut path = Vec::with_capacity(index.depth() as usize);
        for _ in 0..index.depth() {
            let sibling = index.sibling().to_scalar_index() as usize;
            path.push(self.nodes[sibling]);
            index.move_up();
        }

        debug_assert!(index.is_root(), "the path walk must go all the way to the root");

        Ok(path.into())
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [MerkleTree].
    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
        let leaves_start = self.nodes.len() / 2;
        self.nodes
            .iter()
            .skip(leaves_start)
            .enumerate()
            .map(|(i, v)| (i as u64, v.deref()))
    }

    /// Returns an iterator over every inner node of this [MerkleTree].
    ///
    /// The iterator order is unspecified.
    pub fn inner_nodes(&self) -> InnerNodeIterator {
        InnerNodeIterator {
            nodes: &self.nodes,
            index: 1, // index 0 is just padding, start at 1
        }
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Replaces the leaf at the specified index with the provided value.
    ///
    /// # Errors
    /// Returns an error if the specified index value is not a valid leaf value for this tree.
    pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
        let mut index = NodeIndex::new(self.depth(), index_value)?;

        // we don't need to copy the pairs into a new address as we are logically guaranteed to
        // not overlap write instructions. however, it's important to bind the lifetime of pairs
        // to `self.nodes` so the compiler will never move one without moving the other.
        debug_assert_eq!(self.nodes.len() & 1, 0);
        let n = self.nodes.len() / 2;

        // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
        // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee
        // that `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm
        // is logically guaranteed to not overlap write positions as the write index is always
        // half the index from which we read the digest input.
        let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
        let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };

        // update the current node
        let pos = index.to_scalar_index() as usize;
        self.nodes[pos] = value.into();

        // traverse to the root, updating each node with the merged values of its children
        for _ in 0..index.depth() {
            index.move_up();
            let pos = index.to_scalar_index() as usize;
            let value = Rpo256::merge(&pairs[pos]);
            self.nodes[pos] = value;
        }

        Ok(())
    }
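
    // A safe (slightly slower) sketch of the update loop above, added for illustration; it
    // reads the two children explicitly instead of reinterpreting the buffer as pairs:
    //
    //   for _ in 0..index.depth() {
    //       index.move_up();
    //       let pos = index.to_scalar_index() as usize;
    //       let merged = Rpo256::merge(&[self.nodes[pos * 2], self.nodes[pos * 2 + 1]]);
    //       self.nodes[pos] = merged;
    //   }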
}

// CONVERSIONS
// ================================================================================================

impl TryFrom<&[Word]> for MerkleTree {
    type Error = MerkleError;

    fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
        MerkleTree::new(value)
    }
}

impl TryFrom<&[RpoDigest]> for MerkleTree {
    type Error = MerkleError;

    fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
        let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
        MerkleTree::new(value)
    }
}

// ITERATORS
// ================================================================================================

/// An iterator over every inner node of the [MerkleTree].
///
/// Use this to extract the data of the tree; there is no guarantee on the order of the elements.
pub struct InnerNodeIterator<'a> {
    nodes: &'a Vec<RpoDigest>,
    index: usize,
}

impl Iterator for InnerNodeIterator<'_> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index < self.nodes.len() / 2 {
            let value = self.index;
            let left = self.index * 2;
            let right = left + 1;

            self.index += 1;

            Some(InnerNodeInfo {
                value: self.nodes[value],
                left: self.nodes[left],
                right: self.nodes[right],
            })
        } else {
            None
        }
    }
}

// UTILITY FUNCTIONS
// ================================================================================================

/// Utility to visualize a [MerkleTree] in text.
pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
    let indent = "  ";
    let mut s = String::new();
    s.push_str(&word_to_hex(&tree.root())?);
    s.push('\n');
    for d in 1..=tree.depth() {
        let entries = 2u64.pow(d.into());
        for i in 0..entries {
            let index = NodeIndex::new(d, i).expect("The index must always be valid");
            let node = tree.get_node(index).expect("The node must always be found");

            for _ in 0..d {
                s.push_str(indent);
            }
            s.push_str(&word_to_hex(&node)?);
            s.push('\n');
        }
    }

    Ok(s)
}

/// Utility to visualize a [MerklePath] in text.
pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
    let mut s = String::new();
    s.push('[');

    for el in path.iter() {
        s.push_str(&word_to_hex(el)?);
        s.push_str(", ");
    }

    // remove the last ", "
    if !path.is_empty() {
        s.pop();
        s.pop();
    }
    s.push(']');

    Ok(s)
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use core::mem::size_of;

    use proptest::prelude::*;

    use super::*;
    use crate::{
        merkle::{digests_to_words, int_to_leaf, int_to_node},
        Felt, WORD_SIZE,
    };

    const LEAVES4: [RpoDigest; WORD_SIZE] =
        [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

    const LEAVES8: [RpoDigest; 8] = [
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
        int_to_node(8),
    ];

    #[test]
    fn build_merkle_tree() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        assert_eq!(8, tree.nodes.len());

        // leaves were copied correctly
        for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
            assert_eq!(a, b);
        }

        let (root, node2, node3) = compute_internal_nodes();

        assert_eq!(root, tree.nodes[1]);
        assert_eq!(node2, tree.nodes[2]);
        assert_eq!(node3, tree.nodes[3]);

        assert_eq!(root, tree.root());
    }

    #[test]
    fn get_leaf() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();

        // check depth 2
        assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        let (_, node2, node3) = compute_internal_nodes();

        assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn get_path() {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();

        let (_, node2, node3) = compute_internal_nodes();

        // check depth 2
        assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn update_leaf() {
        let mut tree = super::MerkleTree::new(digests_to_words(&LEAVES8)).unwrap();

        // update one leaf
        let value = 3;
        let new_node = int_to_leaf(9);
        let mut expected_leaves = digests_to_words(&LEAVES8);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);

        // update another leaf
        let value = 6;
        let new_node = int_to_leaf(10);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);
    }

    #[test]
    fn nodes() -> Result<(), MerkleError> {
        let tree = super::MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        let root = tree.root();
        let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
        let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
        let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
        let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
        let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
        let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

        let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
        let expected = vec![
            InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
            InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
            InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
        ];
        assert_eq!(nodes, expected);

        Ok(())
    }

    proptest! {
        #[test]
        fn arbitrary_word_can_be_represented_as_digest(
            a in prop::num::u64::ANY,
            b in prop::num::u64::ANY,
            c in prop::num::u64::ANY,
            d in prop::num::u64::ANY,
        ) {
            // this test asserts the memory equivalence between word and digest.
            // it is used to safeguard the [`MerkleTree::update_leaf`] implementation
            // that assumes this equivalence.

            // build a word and copy it to another address as a digest
            let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
            let digest = RpoDigest::from(word);

            // assert the addresses are different
            let word_ptr = word.as_ptr() as *const u8;
            let digest_ptr = digest.as_ptr() as *const u8;
            assert_ne!(word_ptr, digest_ptr);

            // compare the bytes representation
            let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
            let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
            assert_eq!(word_bytes, digest_bytes);
        }
    }

    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
        let node2 =
            Rpo256::hash_elements(&[Word::from(LEAVES4[0]), Word::from(LEAVES4[1])].concat());
        let node3 =
            Rpo256::hash_elements(&[Word::from(LEAVES4[2]), Word::from(LEAVES4[3])].concat());
        let root = Rpo256::merge(&[node2, node3]);

        (root, node2, node3)
    }
}
46
src/merkle/mmr/bit.rs
Normal file
@ -0,0 +1,46 @@
/// Iterates over the bits of a `usize` and yields the bit positions for the true bits.
pub struct TrueBitPositionIterator {
    value: usize,
}

impl TrueBitPositionIterator {
    pub fn new(value: usize) -> TrueBitPositionIterator {
        TrueBitPositionIterator { value }
    }
}

impl Iterator for TrueBitPositionIterator {
    type Item = u32;

    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
        // trailing_zeros is computed with the intrinsic cttz. [Rust 1.67.0] x86 uses the `bsf`
        // instruction. AArch64 uses the `rbit clz` instructions.
        let zeros = self.value.trailing_zeros();

        if zeros == usize::BITS {
            None
        } else {
            let bit_position = zeros;
            let mask = 1 << bit_position;
            self.value ^= mask;
            Some(bit_position)
        }
    }
}

impl DoubleEndedIterator for TrueBitPositionIterator {
    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
        // leading_zeros is computed with the intrinsic ctlz. [Rust 1.67.0] x86 uses the `bsr`
        // instruction. AArch64 uses the `clz` instruction.
        let zeros = self.value.leading_zeros();

        if zeros == usize::BITS {
            None
        } else {
            let bit_position = usize::BITS - zeros - 1;
            let mask = 1 << bit_position;
            self.value ^= mask;
            Some(bit_position)
        }
    }
}
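
// Example (added for illustration): for value 0b1010, the iterator yields bit positions
// 1 then 3; iterating from the back yields 3 then 1.
#[cfg(test)]
#[test]
fn true_bit_positions_sketch() {
    let mut it = TrueBitPositionIterator::new(0b1010);
    assert_eq!(it.next(), Some(1));
    assert_eq!(it.next(), Some(3));
    assert_eq!(it.next(), None);

    let mut it = TrueBitPositionIterator::new(0b1010).rev();
    assert_eq!(it.next(), Some(3));
    assert_eq!(it.next(), Some(1));
    assert_eq!(it.next(), None);
}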
18
src/merkle/mmr/delta.rs
Normal file
@ -0,0 +1,18 @@
use alloc::vec::Vec;

use super::super::RpoDigest;

/// Container for the update data of a [super::PartialMmr].
#[derive(Debug)]
pub struct MmrDelta {
    /// The new version of the [super::Mmr].
    pub forest: usize,

    /// Update data.
    ///
    /// The data is packed as follows:
    /// 1. All the elements needed to perform authentication path updates. These are the right
    ///    siblings required to perform tree merges on the [super::PartialMmr].
    /// 2. The new peaks.
    pub data: Vec<RpoDigest>,
}
27
src/merkle/mmr/error.rs
Normal file
@ -0,0 +1,27 @@
use alloc::string::String;

use thiserror::Error;

use crate::merkle::MerkleError;

#[derive(Debug, Error)]
pub enum MmrError {
    #[error("mmr does not contain position {0}")]
    PositionNotFound(usize),
    #[error("mmr peaks are invalid: {0}")]
    InvalidPeaks(String),
    #[error(
        "mmr peak does not match the computed merkle root of the provided authentication path"
    )]
    PeakPathMismatch,
    #[error("requested peak index is {peak_idx} but the number of peaks is {peaks_len}")]
    PeakOutOfBounds { peak_idx: usize, peaks_len: usize },
    #[error("invalid mmr update")]
    InvalidUpdate,
    #[error("mmr does not contain a peak with depth {0}")]
    UnknownPeak(u8),
    #[error("invalid merkle path")]
    InvalidMerklePath(#[source] MerkleError),
    #[error("merkle root computation failed")]
    MerkleRootComputationFailed(#[source] MerkleError),
}
447
src/merkle/mmr/full.rs
Normal file
@ -0,0 +1,447 @@
//! A fully materialized Merkle mountain range (MMR).
//!
//! An MMR is a forest structure, i.e. it is an ordered set of disjoint rooted trees. The trees
//! are ordered by size, from the most to the least number of leaves. Every tree is a perfect
//! binary tree, meaning a tree has all its leaves at the same depth, and every inner node has a
//! branching factor of 2 with both children set.
//!
//! Additionally, the structure only supports adding leaves to the right-most tree, the one with
//! the least number of leaves. The structure preserves the invariant that each tree has a
//! different depth, i.e. as part of adding a new element to the forest the trees with the same
//! depth are merged, creating a new tree with depth d+1; this process is continued until the
//! property is reestablished.
use alloc::vec::Vec;

use super::{
    super::{InnerNodeInfo, MerklePath},
    bit::TrueBitPositionIterator,
    leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
    RpoDigest,
};

// MMR
// ================================================================================================

/// A fully materialized Merkle Mountain Range, with every tree in the forest and all their
/// elements.
///
/// Since this is a full representation of the MMR, elements are never removed and the MMR will
/// grow roughly `O(2n)` in number of leaf elements.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Mmr {
    /// Refer to the `forest` method documentation for details of the semantics of this value.
    pub(super) forest: usize,

    /// Contains every element of the forest.
    ///
    /// The trees are in postorder sequential representation. This representation allows for all
    /// the elements of every tree in the forest to be stored in the same sequential buffer. It
    /// also means new elements can be added to the forest, and merging of trees is very cheap
    /// with no need to copy elements.
    pub(super) nodes: Vec<RpoDigest>,
}

impl Default for Mmr {
    fn default() -> Self {
        Self::new()
    }
}

impl Mmr {
    // CONSTRUCTORS
    // ============================================================================================

    /// Constructor for an empty `Mmr`.
    pub fn new() -> Mmr {
        Mmr { forest: 0, nodes: Vec::new() }
    }

    // ACCESSORS
    // ============================================================================================

    /// Returns the MMR forest representation.
    ///
    /// The forest value has the following interpretations:
    /// - its value is the number of elements in the forest
    /// - its bit count corresponds to the number of trees in the forest
    /// - each true bit position determines the depth of a tree in the forest
    pub const fn forest(&self) -> usize {
        self.forest
    }
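
    // Example (added for illustration): forest = 0b1011 (11 leaves) encodes three trees
    // with 8, 2, and 1 leaves (depths 3, 1, and 0); adding one more leaf merges the two
    // smallest trees, giving forest = 0b1100.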

    // FUNCTIONALITY
    // ============================================================================================

    /// Returns an [MmrProof] for the leaf at the specified position.
    ///
    /// Note: The leaf position is the 0-indexed number corresponding to the order in which the
    /// leaves were added; it corresponds to the MMR size _prior_ to adding the element. So the
    /// 1st element has position 0, the second position 1, and so on.
    ///
    /// # Errors
    /// Returns an error if the specified leaf position is out of bounds for this MMR.
    pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
        self.open_at(pos, self.forest)
    }

    /// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR
    /// at the specified `forest`.
    ///
    /// Note: The leaf position is the 0-indexed number corresponding to the order in which the
    /// leaves were added; it corresponds to the MMR size _prior_ to adding the element. So the
    /// 1st element has position 0, the second position 1, and so on.
    ///
    /// # Errors
    /// Returns an error if:
    /// - The specified leaf position is out of bounds for this MMR.
    /// - The specified `forest` value is not valid for this MMR.
    pub fn open_at(&self, pos: usize, forest: usize) -> Result<MmrProof, MmrError> {
        // find the target tree responsible for the MMR position
        let tree_bit =
            leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::PositionNotFound(pos))?;

        // isolate the trees before the target
        let forest_before = forest & high_bitmask(tree_bit + 1);
        let index_offset = nodes_in_forest(forest_before);

        // update the value position from global to the target tree
        let relative_pos = pos - forest_before;

        // collect the path and the final index of the target value
        let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

        Ok(MmrProof {
            forest,
            position: pos,
            merkle_path: MerklePath::new(path),
        })
    }

    /// Returns the leaf value at position `pos`.
    ///
    /// Note: The leaf position is the 0-indexed number corresponding to the order in which the
    /// leaves were added; it corresponds to the MMR size _prior_ to adding the element. So the
    /// 1st element has position 0, the second position 1, and so on.
    pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
        // find the target tree responsible for the MMR position
        let tree_bit =
            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;

        // isolate the trees before the target
        let forest_before = self.forest & high_bitmask(tree_bit + 1);
        let index_offset = nodes_in_forest(forest_before);

        // update the value position from global to the target tree
        let relative_pos = pos - forest_before;

        // collect the path and the final index of the target value
        let (value, _) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

        Ok(value)
    }

    /// Adds a new element to the MMR.
    pub fn add(&mut self, el: RpoDigest) {
        // Note: every node is also a tree of size 1, so adding an element to the forest creates
        // a new rooted tree of size 1. This may temporarily break the invariant that every tree
        // in the forest has a different size; the loop below eagerly merges trees of the same
        // size and restores the invariant.
        self.nodes.push(el);

        let mut left_offset = self.nodes.len().saturating_sub(2);
        let mut right = el;
        let mut left_tree = 1;
        while self.forest & left_tree != 0 {
            right = Rpo256::merge(&[self.nodes[left_offset], right]);
            self.nodes.push(right);

            left_offset = left_offset.saturating_sub(nodes_in_forest(left_tree));
            left_tree <<= 1;
        }

        self.forest += 1;
    }
|
||||
|
||||
    /// Returns the current peaks of the MMR.
    pub fn peaks(&self) -> MmrPeaks {
        self.peaks_at(self.forest).expect("failed to get peaks at current forest")
    }

    /// Returns the peaks of the MMR at the state specified by `forest`.
    ///
    /// # Errors
    /// Returns an error if the specified `forest` value is not valid for this MMR.
    pub fn peaks_at(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
        if forest > self.forest {
            return Err(MmrError::InvalidPeaks(format!(
                "requested forest {forest} exceeds current forest {}",
                self.forest
            )));
        }

        let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
            .rev()
            .map(|bit| nodes_in_forest(1 << bit))
            .scan(0, |offset, el| {
                *offset += el;
                Some(*offset)
            })
            .map(|offset| self.nodes[offset - 1])
            .collect();

        // Safety: the invariant is maintained by the [Mmr]
        let peaks = MmrPeaks::new(forest, peaks).unwrap();

        Ok(peaks)
    }

    /// Computes the update required to go from `from_forest` to `to_forest`.
    ///
    /// The result is a packed sequence of the authentication elements required to update the trees
    /// that have been merged together, followed by the new peaks of the [Mmr].
    pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
        if to_forest > self.forest || from_forest > to_forest {
            return Err(MmrError::InvalidPeaks(format!(
                "to_forest {to_forest} exceeds the current forest {} or from_forest {from_forest} exceeds to_forest",
                self.forest
            )));
        }

        if from_forest == to_forest {
            return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
        }

        let mut result = Vec::new();

        // Find the largest tree in this [Mmr] which is new to `from_forest`.
        let candidate_trees = to_forest ^ from_forest;
        let mut new_high = 1 << candidate_trees.ilog2();

        // Collect authentication nodes used for tree merges
        // ----------------------------------------------------------------------------------------

        // Find the trees from `from_forest` that have been merged into `new_high`.
        let mut merges = from_forest & (new_high - 1);

        // Find the peaks that are common to `from_forest` and this [Mmr]
        let common_trees = from_forest ^ merges;

        if merges != 0 {
            // Skip the smallest trees unknown to `from_forest`.
            let mut target = 1 << merges.trailing_zeros();

            // Collect siblings required to compute the merged tree's peak
            while target < new_high {
                // Computes the offset to the smallest known peak
                // - common_trees: peaks unchanged in the current update, target comes after these.
                // - merges: peaks that have not been merged so far, target comes after these.
                // - target: tree from which to load the sibling. On the first iteration this is a
                //   value known by the partial MMR, on subsequent iterations this value is to be
                //   computed from the known peaks and provided authentication nodes.
                let known = nodes_in_forest(common_trees | merges | target);
                let sibling = nodes_in_forest(target);
                result.push(self.nodes[known + sibling - 1]);

                // Update the target and account for tree merges
                target <<= 1;
                while merges & target != 0 {
                    target <<= 1;
                }
                // Remove the merges done so far
                merges ^= merges & (target - 1);
            }
        } else {
            // The new high tree may not be the result of any merges, if it is smaller than all the
            // trees of `from_forest`.
            new_high = 0;
        }

        // Collect the new [Mmr] peaks
        // ----------------------------------------------------------------------------------------

        let mut new_peaks = to_forest ^ common_trees ^ new_high;
        let old_peaks = to_forest ^ new_peaks;
        let mut offset = nodes_in_forest(old_peaks);
        while new_peaks != 0 {
            let target = 1 << new_peaks.ilog2();
            offset += nodes_in_forest(target);
            result.push(self.nodes[offset - 1]);
            new_peaks ^= target;
        }

        Ok(MmrDelta { forest: to_forest, data: result })
    }

    /// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
    pub fn inner_nodes(&self) -> MmrNodes {
        MmrNodes {
            mmr: self,
            forest: 0,
            last_right: 0,
            index: 0,
        }
    }

    // UTILITIES
    // ============================================================================================

    /// Internal function used to collect the Merkle path of a value.
    ///
    /// The arguments are relative to the target tree. To compute the opening of the second leaf
    /// for a tree with depth 2 in the forest `0b110`:
    ///
    /// - `tree_bit`: Bit position of the target tree, e.g. 1 for the smallest tree, whose depth
    ///   is `tree_bit + 1 = 2`.
    /// - `relative_pos`: 0-indexed leaf position in the target tree, e.g. 1 for the second leaf.
    /// - `index_offset`: Node count prior to the target tree, e.g. 7 for the tree of depth 3.
    fn collect_merkle_path_and_value(
        &self,
        tree_bit: u32,
        relative_pos: usize,
        index_offset: usize,
    ) -> (RpoDigest, Vec<RpoDigest>) {
        // see documentation of `leaf_to_corresponding_tree` for details
        let tree_depth = (tree_bit + 1) as usize;
        let mut path = Vec::with_capacity(tree_depth);

        // The tree walk below goes from the root to the leaf, so compute the root index to start
        let mut forest_target = 1usize << tree_bit;
        let mut index = nodes_in_forest(forest_target) - 1;

        // Loop until the leaf is reached
        while forest_target > 1 {
            // Update the depth of the tree to correspond to a subtree
            forest_target >>= 1;

            // compute the indices of the right and left subtrees based on the post-order
            let right_offset = index - 1;
            let left_offset = right_offset - nodes_in_forest(forest_target);

            let left_or_right = relative_pos & forest_target;
            let sibling = if left_or_right != 0 {
                // going down the right subtree, the right child becomes the new root
                index = right_offset;
                // and the left child is the authentication node
                self.nodes[index_offset + left_offset]
            } else {
                index = left_offset;
                self.nodes[index_offset + right_offset]
            };

            path.push(sibling);
        }

        debug_assert!(path.len() == tree_depth - 1);

        // the rest of the codebase has the elements going from leaf to root; reverse it here for
        // ease of use and consistency
        path.reverse();

        let value = self.nodes[index_offset + index];
        (value, path)
    }
}
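
// A hedged, illustrative sketch (not part of the original diff): exercises the
// forest encoding documented on `Mmr::forest` and the shape of the update
// returned by `get_delta`. It assumes the crate-internal test helper
// `int_to_node` used by the other test modules in this crate.
#[cfg(test)]
mod mmr_usage_sketch {
    use super::Mmr;
    use crate::merkle::int_to_node;

    #[test]
    fn forest_bits_match_trees_and_positions() {
        // 11 = 0b1011 leaves -> three perfect trees, hence three peaks
        let mmr = Mmr::from((0..11).map(int_to_node));
        assert_eq!(mmr.forest(), 0b1011);
        assert_eq!(mmr.peaks().peaks().len(), 3);
        // leaf positions are 0-indexed: the 11th element sits at position 10
        assert!(mmr.open(10).is_ok());
        assert!(mmr.open(11).is_err());
    }

    #[test]
    fn delta_packs_merge_nodes_then_new_peaks() {
        let mmr = Mmr::from((0..8).map(int_to_node));
        // same forest on both ends -> empty delta
        assert!(mmr.get_delta(8, 8).unwrap().data.is_empty());
        // from 7 (0b0111) to 8 (0b1000): the three old peaks merge into one
        // tree, and the only element unknown to the old state is the 8th leaf
        let delta = mmr.get_delta(7, 8).unwrap();
        assert_eq!(delta.forest, 8);
        assert_eq!(delta.data.len(), 1);
    }
}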

// CONVERSIONS
// ================================================================================================

impl<T> From<T> for Mmr
where
    T: IntoIterator<Item = RpoDigest>,
{
    fn from(values: T) -> Self {
        let mut mmr = Mmr::new();
        for v in values {
            mmr.add(v)
        }
        mmr
    }
}

// ITERATOR
// ================================================================================================

/// Yields inner nodes of the [Mmr].
pub struct MmrNodes<'a> {
    /// [Mmr] being yielded; when its `forest` value is matched, the iteration is finished.
    mmr: &'a Mmr,
    /// Keeps track of the left nodes yielded so far waiting for a right pair. This matches the
    /// semantics of the [Mmr]'s forest attribute, since that too works as a buffer of left nodes
    /// waiting for a pair to be hashed together.
    forest: usize,
    /// Keeps track of the last right node yielded; after this value is set, the next iteration
    /// will be its parent with its corresponding left node that has been yielded already.
    last_right: usize,
    /// The current index in the `nodes` vector.
    index: usize,
}

impl Iterator for MmrNodes<'_> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {
        debug_assert!(self.last_right.count_ones() <= 1, "last_right tracks zero or one element");

        // only parent nodes are emitted, remove the single node tree from the forest
        let target = self.mmr.forest & (usize::MAX << 1);

        if self.forest < target {
            if self.last_right == 0 {
                // yield the left leaf
                debug_assert!(self.last_right == 0, "left must be before right");
                self.forest |= 1;
                self.index += 1;

                // yield the right leaf
                debug_assert!((self.forest & 1) == 1, "right must be after left");
                self.last_right |= 1;
                self.index += 1;
            };

            debug_assert!(
                self.forest & self.last_right != 0,
                "parent requires both a left and right",
            );

            // compute the number of nodes in the right tree, this is the offset to the
            // previous left parent
            let right_nodes = nodes_in_forest(self.last_right);
            // the next parent position is one above the position of the pair
            let parent = self.last_right << 1;

            // the left node has been paired and the current parent yielded, so remove it from
            // the forest
            self.forest ^= self.last_right;
            if self.forest & parent == 0 {
                // this iteration yielded the left parent node
                debug_assert!(self.forest & 1 == 0, "next iteration yields a left leaf");
                self.last_right = 0;
                self.forest ^= parent;
            } else {
                // the left node of the parent level has been yielded already, this iteration
                // was the right parent. Next iteration yields their parent.
                self.last_right = parent;
            }

            // yields a parent
            let value = self.mmr.nodes[self.index];
            let right = self.mmr.nodes[self.index - 1];
            let left = self.mmr.nodes[self.index - 1 - right_nodes];
            self.index += 1;
            let node = InnerNodeInfo { value, left, right };

            Some(node)
        } else {
            None
        }
    }
}
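
// Hedged sketch (not in the original diff): every item yielded by `MmrNodes`
// is a parent node whose value hashes its two children, which is the pairing
// invariant the bookkeeping above maintains. Uses the crate's `int_to_node`
// test helper.
#[cfg(test)]
mod inner_nodes_sketch {
    use alloc::vec::Vec;

    use super::{Mmr, Rpo256};
    use crate::merkle::int_to_node;

    #[test]
    fn yielded_parents_hash_their_children() {
        // forest 0b111: trees of 4, 2, and 1 leaves -> 3 + 1 + 0 = 4 parents
        let mmr = Mmr::from((0..7).map(int_to_node));
        let inner: Vec<_> = mmr.inner_nodes().collect();
        assert_eq!(inner.len(), 4);
        for node in inner {
            assert_eq!(node.value, Rpo256::merge(&[node.left, node.right]));
        }
    }
}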

// UTILITIES
// ================================================================================================

/// Return a bitmask for the bits including and above the given position.
pub(crate) const fn high_bitmask(bit: u32) -> usize {
    if bit > usize::BITS - 1 {
        0
    } else {
        usize::MAX << bit
    }
}
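
// Hedged sketch (not in the original diff): the behavior of `high_bitmask`,
// including the guard that avoids the otherwise-overflowing shift
// `usize::MAX << usize::BITS`.
#[cfg(test)]
mod high_bitmask_sketch {
    use super::high_bitmask;

    #[test]
    fn masks_bits_at_and_above_position() {
        assert_eq!(high_bitmask(0), usize::MAX);
        assert_eq!(high_bitmask(2), usize::MAX << 2);
        assert_eq!(high_bitmask(usize::BITS), 0); // guard case
    }
}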

191 src/merkle/mmr/inorder.rs Normal file
@@ -0,0 +1,191 @@
//! Index for nodes of a binary tree based on an in-order tree walk.
//!
//! In-order walks have the parent node index split its left and right subtrees. All the left
//! children have indexes lower than the parent, while all the nodes in the right subtree have
//! higher indexes. This property makes it easy to compute changes to the index by adding or
//! subtracting the leaves count.
use core::num::NonZeroUsize;

use winter_utils::{Deserializable, Serializable};

// IN-ORDER INDEX
// ================================================================================================

/// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct InOrderIndex {
    idx: usize,
}

impl InOrderIndex {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [InOrderIndex] instantiated from the provided value.
    pub fn new(idx: NonZeroUsize) -> InOrderIndex {
        InOrderIndex { idx: idx.get() }
    }

    /// Returns a new [InOrderIndex] instantiated from the specified leaf position.
    ///
    /// # Panics
    /// If `leaf` is higher than or equal to `usize::MAX / 2`.
    pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
        // Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
        // implementation only works with 1-indexed counting.
        let pos = leaf + 1;
        InOrderIndex { idx: pos * 2 - 1 }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// True if the index is pointing at a leaf.
    ///
    /// Every odd number represents a leaf.
    pub fn is_leaf(&self) -> bool {
        self.idx & 1 == 1
    }

    /// Returns true if this node is a left child of its parent.
    pub fn is_left_child(&self) -> bool {
        self.parent().left_child() == *self
    }

    /// Returns the level of the index.
    ///
    /// Starts at level zero for leaves and increases by one for each parent.
    pub fn level(&self) -> u32 {
        self.idx.trailing_zeros()
    }

    /// Returns the index of the left child.
    ///
    /// # Panics
    /// If the index corresponds to a leaf.
    pub fn left_child(&self) -> InOrderIndex {
        // The left child is itself a parent, with an index that splits its left/right subtrees. To
        // go from the parent index to its left child, it is only necessary to subtract the count
        // of elements on the child's right subtree + 1.
        let els = 1 << (self.level() - 1);
        InOrderIndex { idx: self.idx - els }
    }

    /// Returns the index of the right child.
    ///
    /// # Panics
    /// If the index corresponds to a leaf.
    pub fn right_child(&self) -> InOrderIndex {
        // To compute the index of the parent of the right subtree it is sufficient to add the size
        // of its left subtree + 1.
        let els = 1 << (self.level() - 1);
        InOrderIndex { idx: self.idx + els }
    }

    /// Returns the index of the parent node.
    pub fn parent(&self) -> InOrderIndex {
        // If the current index corresponds to a node in a left tree, going up a level requires
        // adding the number of nodes of the right sibling; analogously, if the node is a right
        // child, going up requires subtracting the number of nodes in its left subtree.
        //
        // Both of the above operations can be performed by bitwise manipulation. Below, the mask
        // sets the number of trailing zeros to be equal to the new level of the index, and the bit
        // marks the parent.
        let target = self.level() + 1;
        let bit = 1 << target;
        let mask = bit - 1;
        let idx = self.idx ^ (self.idx & mask);
        InOrderIndex { idx: idx | bit }
    }

    /// Returns the index of the sibling node.
    pub fn sibling(&self) -> InOrderIndex {
        let parent = self.parent();
        if *self > parent {
            parent.left_child()
        } else {
            parent.right_child()
        }
    }

    /// Returns the inner value of this [InOrderIndex].
    pub fn inner(&self) -> u64 {
        self.idx as u64
    }
}
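
// Hedged sketch (not in the original diff): the in-order index arithmetic
// described in the module docs. Leaves are the odd 1-indexed values, and
// parent/child moves are pure bit manipulation.
#[cfg(test)]
mod inorder_walk_sketch {
    use super::InOrderIndex;

    #[test]
    fn parent_and_children_roundtrip() {
        // third leaf (0-indexed position 2) has in-order value 5
        let leaf = InOrderIndex::from_leaf_pos(2);
        assert_eq!(leaf.inner(), 5);
        assert_eq!(leaf.level(), 0);
        assert!(leaf.is_leaf());

        // its parent has value 6 and sits one level up
        let parent = leaf.parent();
        assert_eq!(parent.inner(), 6);
        assert_eq!(parent.level(), 1);
        assert_eq!(parent.left_child(), leaf);
        assert_eq!(parent.right_child().inner(), 7);
    }
}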

impl Serializable for InOrderIndex {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        target.write_usize(self.idx);
    }
}

impl Deserializable for InOrderIndex {
    fn read_from<R: winter_utils::ByteReader>(
        source: &mut R,
    ) -> Result<Self, winter_utils::DeserializationError> {
        let idx = source.read_usize()?;
        Ok(InOrderIndex { idx })
    }
}

// CONVERSIONS FROM IN-ORDER INDEX
// ------------------------------------------------------------------------------------------------

impl From<InOrderIndex> for u64 {
    fn from(index: InOrderIndex) -> Self {
        index.idx as u64
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use winter_utils::{Deserializable, Serializable};

    use super::InOrderIndex;

    proptest! {
        #[test]
        fn proptest_inorder_index_random(count in 1..1000usize) {
            let left_pos = count * 2;
            let right_pos = count * 2 + 1;

            let left = InOrderIndex::from_leaf_pos(left_pos);
            let right = InOrderIndex::from_leaf_pos(right_pos);

            assert!(left.is_leaf());
            assert!(right.is_leaf());
            assert_eq!(left.parent(), right.parent());
            assert_eq!(left.parent().right_child(), right);
            assert_eq!(left, right.parent().left_child());
            assert_eq!(left.sibling(), right);
            assert_eq!(left, right.sibling());
        }
    }

    #[test]
    fn test_inorder_index_basic() {
        let left = InOrderIndex::from_leaf_pos(0);
        let right = InOrderIndex::from_leaf_pos(1);

        assert!(left.is_leaf());
        assert!(right.is_leaf());
        assert_eq!(left.parent(), right.parent());
        assert_eq!(left.parent().right_child(), right);
        assert_eq!(left, right.parent().left_child());
        assert_eq!(left.sibling(), right);
        assert_eq!(left, right.sibling());
    }

    #[test]
    fn test_inorder_index_serialization() {
        let index = InOrderIndex::from_leaf_pos(5);
        let bytes = index.to_bytes();
        let index2 = InOrderIndex::read_from_bytes(&bytes).unwrap();
        assert_eq!(index, index2);
    }
}

67 src/merkle/mmr/mod.rs Normal file
@@ -0,0 +1,67 @@
mod bit;
mod delta;
mod error;
mod full;
mod inorder;
mod partial;
mod peaks;
mod proof;

#[cfg(test)]
mod tests;

// REEXPORTS
// ================================================================================================
pub use delta::MmrDelta;
pub use error::MmrError;
pub use full::Mmr;
pub use inorder::InOrderIndex;
pub use partial::PartialMmr;
pub use peaks::MmrPeaks;
pub use proof::MmrProof;

use super::{Felt, Rpo256, RpoDigest, Word};

// UTILITIES
// ================================================================================================

/// Given a 0-indexed leaf position and the current forest, returns the tree number responsible
/// for the position.
///
/// Note:
/// The result is a tree position `p` with the following interpretations: `p + 1` is the depth of
/// the tree; because the root element is not part of the proof, `p` is the length of the
/// authentication path; `2^p` is the number of leaves in this particular tree; and `2^(p+1) - 1`
/// is the total number of nodes in the tree.
const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        None
    } else {
        // - each bit in the forest is a unique tree, and the bit position gives its power-of-two
        //   size
        // - each tree owns a consecutive range of positions equal to its size, from left to right
        // - this means the first tree owns the `2^k_0` first positions, where `k_0` is the highest
        //   true bit position, the second tree owns the next `2^k_1` positions, where `k_1` is the
        //   second highest bit, and so on
        // - this means the highest bits work as a category marker, and the position is owned by
        //   the first tree which doesn't share a high bit with the position
        let before = forest & pos;
        let after = forest ^ before;
        let tree = after.ilog2();

        Some(tree)
    }
}

/// Returns the total number of nodes of a given forest.
///
/// # Panics
///
/// This will panic if the forest has size greater than `usize::MAX / 2`.
const fn nodes_in_forest(forest: usize) -> usize {
    // - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
    // - the forest represents the sum of the $2^k$ terms, so a single multiplication is necessary
    // - the number of `-1` terms is the same as the number of trees, which is the same as the
    //   number of bits set
    let tree_count = forest.count_ones() as usize;
    forest * 2 - tree_count
}
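
// Hedged sketch (not in the original diff): the two utilities above, applied
// to the forest 0b110 (6 leaves split into a 4-leaf tree followed by a 2-leaf
// tree).
#[cfg(test)]
mod forest_utils_sketch {
    use super::{leaf_to_corresponding_tree, nodes_in_forest};

    #[test]
    fn positions_map_to_trees_and_node_counts() {
        // positions 0..4 belong to the tree at bit 2, positions 4..6 to bit 1
        assert_eq!(leaf_to_corresponding_tree(0, 0b110), Some(2));
        assert_eq!(leaf_to_corresponding_tree(3, 0b110), Some(2));
        assert_eq!(leaf_to_corresponding_tree(4, 0b110), Some(1));
        assert_eq!(leaf_to_corresponding_tree(6, 0b110), None);

        // 2 * 6 leaves - 2 trees = 10 nodes in total
        assert_eq!(nodes_in_forest(0b110), 10);
    }
}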

952 src/merkle/mmr/partial.rs Normal file
@@ -0,0 +1,952 @@
use alloc::{
    collections::{BTreeMap, BTreeSet},
    vec::Vec,
};

use winter_utils::{Deserializable, Serializable};

use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
use crate::merkle::{
    mmr::{leaf_to_corresponding_tree, nodes_in_forest},
    InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
};

// TYPE ALIASES
// ================================================================================================

type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;

// PARTIAL MERKLE MOUNTAIN RANGE
// ================================================================================================
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
/// authentication paths for a subset of the elements in a full MMR.
///
/// This structure stores only the authentication path for a value; the value itself is stored
/// separately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMmr {
    /// The version of the MMR.
    ///
    /// This value serves the following purposes:
    ///
    /// - The forest is a counter for the total number of elements in the MMR.
    /// - Since the MMR is an append-only structure, every change to it causes a change to the
    ///   `forest`, so this value has a dual purpose as a version tag.
    /// - The bits in the forest also correspond to the count and size of every perfect binary
    ///   tree that composes the MMR structure, which serve to compute indexes and perform
    ///   validation.
    pub(crate) forest: usize,

    /// The MMR peaks.
    ///
    /// The peaks are used for two reasons:
    ///
    /// 1. They authenticate the addition of an element to the [PartialMmr], ensuring only valid
    ///    elements are tracked.
    /// 2. During an MMR update, peaks can be merged by hashing the left and right hand sides. The
    ///    peaks are used as the left hand side.
    ///
    /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
    /// leaves, starting from the peak with most children, to the one with least.
    pub(crate) peaks: Vec<RpoDigest>,

    /// Authentication nodes used to construct Merkle paths for a subset of the MMR's leaves.
    ///
    /// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
    /// construct their authentication paths. This property is used to detect when elements can be
    /// safely removed, because they are no longer required to authenticate any element in the
    /// [PartialMmr].
    ///
    /// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
    /// permits easy computation of the relative nodes (left/right children, sibling, parent),
    /// which is useful for traversal. The indexing is also stable, meaning that merges to the
    /// trees in the MMR can be represented without rewrites of the indexes.
    pub(crate) nodes: NodeMap,

    /// Flag indicating whether the odd element should be tracked.
    ///
    /// This flag is necessary because the sibling of the odd element doesn't exist yet, so it
    /// cannot be added into `nodes` to signal that the value is being tracked.
    pub(crate) track_latest: bool,
}

impl PartialMmr {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [PartialMmr] instantiated from the specified peaks.
    pub fn from_peaks(peaks: MmrPeaks) -> Self {
        let forest = peaks.num_leaves();
        let peaks = peaks.into();
        let nodes = BTreeMap::new();
        let track_latest = false;

        Self { forest, peaks, nodes, track_latest }
    }

    /// Returns a new [PartialMmr] instantiated from the specified components.
    ///
    /// This constructor does not check the consistency between peaks and nodes. If the specified
    /// peaks and nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
    pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
        let forest = peaks.num_leaves();
        let peaks = peaks.into();

        Self { forest, peaks, nodes, track_latest }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the current `forest` of this [PartialMmr].
    ///
    /// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
    /// underlying MMR.
    pub fn forest(&self) -> usize {
        self.forest
    }

    /// Returns the number of leaves in the underlying MMR for this [PartialMmr].
    pub fn num_leaves(&self) -> usize {
        self.forest
    }

    /// Returns the peaks of the MMR for this [PartialMmr].
    pub fn peaks(&self) -> MmrPeaks {
        // expect() is OK here because the constructor ensures that MMR peaks can be constructed
        // correctly
        MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
    }

    /// Returns true if this partial MMR tracks an authentication path for the leaf at the
    /// specified position.
    pub fn is_tracked(&self, pos: usize) -> bool {
        if pos >= self.forest {
            return false;
        } else if pos == self.forest - 1 && self.forest & 1 != 0 {
            // if the number of leaves in the MMR is odd and the position is for the last leaf,
            // whether the leaf is tracked is defined by the `track_latest` flag
            return self.track_latest;
        }

        let leaf_index = InOrderIndex::from_leaf_pos(pos);
        self.is_tracked_node(&leaf_index)
    }

    /// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
    /// partial MMR does not track an authentication path for the specified leaf.
    ///
    /// Note: The leaf position is the 0-indexed number corresponding to the order in which the
    /// leaves were added; it matches the MMR size _prior_ to adding the element. So the 1st
    /// element has position 0, the second position 1, and so on.
    ///
    /// # Errors
    /// Returns an error if the specified position is greater than or equal to the number of
    /// leaves in the underlying MMR.
    pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
        let tree_bit =
            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
        let depth = tree_bit as usize;

        let mut nodes = Vec::with_capacity(depth);
        let mut idx = InOrderIndex::from_leaf_pos(pos);

        while let Some(node) = self.nodes.get(&idx.sibling()) {
            nodes.push(*node);
            idx = idx.parent();
        }

        // If there are nodes then the path must be complete, otherwise it is a bug
        debug_assert!(nodes.is_empty() || nodes.len() == depth);

        if nodes.len() != depth {
            // The requested `pos` is not being tracked.
            Ok(None)
        } else {
            Ok(Some(MmrProof {
                forest: self.forest,
                position: pos,
                merkle_path: MerklePath::new(nodes),
            }))
        }
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the nodes of all authentication paths of this [PartialMmr].
    pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
        self.nodes.iter()
    }

    /// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
    ///
    /// The order of iteration is not defined. If a leaf is not present in this partial MMR, it
    /// is silently ignored.
    pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
        &'a self,
        mut leaves: I,
    ) -> impl Iterator<Item = InnerNodeInfo> + 'a {
        let stack = if let Some((pos, leaf)) = leaves.next() {
            let idx = InOrderIndex::from_leaf_pos(pos);
            vec![(idx, leaf)]
        } else {
            Vec::new()
        };

        InnerNodeIterator {
            nodes: &self.nodes,
            leaves,
            stack,
            seen_nodes: BTreeSet::new(),
        }
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Adds a new peak and optionally tracks it. Returns a vector of the authentication nodes
    /// inserted into this [PartialMmr] as a result of this operation.
    ///
    /// When `track` is `true` the new leaf is tracked.
    pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
        self.forest += 1;
        let merges = self.forest.trailing_zeros() as usize;
        let mut new_nodes = Vec::with_capacity(merges);

        let peak = if merges == 0 {
            self.track_latest = track;
            leaf
        } else {
            let mut track_right = track;
            let mut track_left = self.track_latest;

            let mut right = leaf;
            let mut right_idx = forest_to_rightmost_index(self.forest);

            for _ in 0..merges {
                let left = self.peaks.pop().expect("Missing peak");
                let left_idx = right_idx.sibling();

                if track_right {
                    let old = self.nodes.insert(left_idx, left);
                    new_nodes.push((left_idx, left));

                    debug_assert!(
                        old.is_none(),
                        "Idx {:?} already contained an element {:?}",
                        left_idx,
                        old
                    );
                };
                if track_left {
                    let old = self.nodes.insert(right_idx, right);
                    new_nodes.push((right_idx, right));

                    debug_assert!(
                        old.is_none(),
                        "Idx {:?} already contained an element {:?}",
                        right_idx,
                        old
                    );
                };

                // Update state for the next iteration.
                // --------------------------------------------------------------------------------

                // This layer is merged, go up one layer.
                right_idx = right_idx.parent();

                // Merge the current layer. The result is either the right element of the next
                // merge, or a new peak.
                right = Rpo256::merge(&[left, right]);

                // This iteration merged the left and right nodes, and the new value is always used
                // as the next iteration's right node. Therefore the tracking flags of this
                // iteration have to be merged into the right side only.
                track_right = track_right || track_left;

                // On the next iteration, a peak will be merged. If any of its children are
                // tracked, then we have to track the left side
                track_left = self.is_tracked_node(&right_idx.sibling());
            }
            right
        };

        self.peaks.push(peak);

        new_nodes
    }

    /// Adds the authentication path represented by [MerklePath] if it is valid.
    ///
    /// The `leaf_pos` refers to the global position of the leaf in the MMR. These are 0-indexed
    /// values assigned in a strictly monotonic fashion as elements are inserted into the MMR;
    /// this value corresponds to the values used in the MMR structure.
    ///
    /// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path
    /// for that element up to its corresponding Mmr peak. The `leaf` is only used to compute the
    /// root from the authentication path to validate the data; only the authentication data is
    /// saved in the structure. If the value is required it should be stored out-of-band.
    pub fn track(
        &mut self,
        leaf_pos: usize,
        leaf: RpoDigest,
        path: &MerklePath,
    ) -> Result<(), MmrError> {
        // Checks there is a tree with the same depth as the authentication path; if not, the path
        // is invalid.
        let tree = 1 << path.depth();
        if tree & self.forest == 0 {
            return Err(MmrError::UnknownPeak(path.depth()));
        };

        if leaf_pos + 1 == self.forest
            && path.depth() == 0
            && self.peaks.last().is_some_and(|v| *v == leaf)
        {
            self.track_latest = true;
            return Ok(());
        }

        // ignore the trees smaller than the target (these elements are positioned after the
        // current target and don't affect the target leaf_pos)
        let target_forest = self.forest ^ (self.forest & (tree - 1));
        let peak_pos = (target_forest.count_ones() - 1) as usize;

        // translate from mmr leaf_pos to merkle path
        let path_idx = leaf_pos - (target_forest ^ tree);

        // Compute the root of the authentication path, and check it matches the current version of
        // the PartialMmr.
        let computed = path
            .compute_root(path_idx as u64, leaf)
            .map_err(MmrError::MerkleRootComputationFailed)?;
        if self.peaks[peak_pos] != computed {
            return Err(MmrError::PeakPathMismatch);
        }

        let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
        for leaf in path.nodes() {
            self.nodes.insert(idx.sibling(), *leaf);
            idx = idx.parent();
        }

        Ok(())
    }

    /// Removes a leaf from the [PartialMmr] and the unused nodes from the authentication path.
    ///
    /// Note: `leaf_pos` corresponds to the position in the MMR and not in an individual tree.
    pub fn untrack(&mut self, leaf_pos: usize) {
        let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);

        self.nodes.remove(&idx.sibling());

        // `idx` represents the element that can be computed by the authentication path; because
        // these elements can be computed, they are not saved for the authentication of the current
        // target. In other words, if the idx is present it was added for the authentication of
        // another element, and no more elements should be removed, otherwise it would remove that
        // element's authentication data.
        while !self.nodes.contains_key(&idx) {
            idx = idx.parent();
            self.nodes.remove(&idx.sibling());
        }
    }

    /// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
    /// inserted into the partial MMR.
    pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
        if delta.forest < self.forest {
            return Err(MmrError::InvalidPeaks(format!(
                "forest of mmr delta {} is less than current forest {}",
                delta.forest, self.forest
            )));
        }

        let mut inserted_nodes = Vec::new();

        if delta.forest == self.forest {
            if !delta.data.is_empty() {
                return Err(MmrError::InvalidUpdate);
            }

            return Ok(inserted_nodes);
        }

        // find the tree merges
        let changes = self.forest ^ delta.forest;
        let largest = 1 << changes.ilog2();
        let merges = self.forest & (largest - 1);

        debug_assert!(
            !self.track_latest || (merges & 1) == 1,
            "if there is an odd element, a merge is required"
        );

        // count the number of elements needed to produce largest from the current state
        let (merge_count, new_peaks) = if merges != 0 {
            let depth = largest.trailing_zeros();
            let skipped = merges.trailing_zeros();
            let computed = merges.count_ones() - 1;
            let merge_count = depth - skipped - computed;

            let new_peaks = delta.forest & (largest - 1);

            (merge_count, new_peaks)
        } else {
            (0, changes)
        };

        // verify the delta size
        if (delta.data.len() as u32) != merge_count + new_peaks.count_ones() {
            return Err(MmrError::InvalidUpdate);
        }

        // keeps track of how many data elements from the update have been consumed
        let mut update_count = 0;

        if merges != 0 {
            // starts at the smallest peak and follows the merged peaks
            let mut peak_idx = forest_to_root_index(self.forest);

            // match order of the update data while applying it
            self.peaks.reverse();

            // set to true when the data is needed for authentication path updates
            let mut track = self.track_latest;
            self.track_latest = false;

            let mut peak_count = 0;
            let mut target = 1 << merges.trailing_zeros();
            let mut new = delta.data[0];
            update_count += 1;

            while target < largest {
                // check if either the left or right subtree has nodes saved for authentication
                // paths. If so, turn tracking on to update those paths.
                if target != 1 && !track {
                    track = self.is_tracked_node(&peak_idx);
                }

                // update data only contains the nodes from the right subtrees, left nodes are
                // either previously known peaks or computed values
                let (left, right) = if target & merges != 0 {
                    let peak = self.peaks[peak_count];
                    let sibling_idx = peak_idx.sibling();

                    // if the sibling peak is tracked, add this peak to the set of
                    // authentication nodes
                    if self.is_tracked_node(&sibling_idx) {
                        self.nodes.insert(peak_idx, new);
                        inserted_nodes.push((peak_idx, new));
                    }
                    peak_count += 1;
                    (peak, new)
                } else {
                    let update = delta.data[update_count];
                    update_count += 1;
                    (new, update)
                };

                if track {
                    let sibling_idx = peak_idx.sibling();
                    if peak_idx.is_left_child() {
                        self.nodes.insert(sibling_idx, right);
                        inserted_nodes.push((sibling_idx, right));
                    } else {
                        self.nodes.insert(sibling_idx, left);
                        inserted_nodes.push((sibling_idx, left));
                    }
                }

                peak_idx = peak_idx.parent();
                new = Rpo256::merge(&[left, right]);
                target <<= 1;
            }

            debug_assert!(peak_count == (merges.count_ones() as usize));

            // restore the peaks order
            self.peaks.reverse();
            // remove the merged peaks
            self.peaks.truncate(self.peaks.len() - peak_count);
            // add the newly computed peak, the result of the merges
            self.peaks.push(new);
        }

        // The rest of the update data is composed of peaks. None of these elements can contain
        // tracked elements because the peaks were unknown, and it is not possible to add elements
        // for tracking without authenticating them to a peak.
        self.peaks.extend_from_slice(&delta.data[update_count..]);
        self.forest = delta.forest;

        debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));

        Ok(inserted_nodes)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns true if this [PartialMmr] tracks an authentication path for the node at the
    /// specified index.
    fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
        if node_index.is_leaf() {
            self.nodes.contains_key(&node_index.sibling())
        } else {
            let left_child = node_index.left_child();
            let right_child = node_index.right_child();
            self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
        }
    }
}
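
// Hedged usage sketch (not in the original diff): the track/open flow
// described above, run against a small 7-leaf MMR. `int_to_node` is the
// crate-internal test helper also used by the test module below.
#[cfg(test)]
mod tracking_sketch {
    use super::PartialMmr;
    use crate::merkle::{int_to_node, Mmr};

    #[test]
    fn track_then_open() {
        let mmr = Mmr::from((0..7).map(int_to_node));
        let mut partial = PartialMmr::from_peaks(mmr.peaks());

        // nothing is tracked yet, so open() returns Ok(None)
        assert_eq!(partial.open(2).unwrap(), None);
        assert!(!partial.is_tracked(2));

        // track leaf 2: the supplied path is validated against the stored peaks
        let proof = mmr.open(2).unwrap();
        partial.track(2, mmr.get(2).unwrap(), &proof.merkle_path).unwrap();

        assert!(partial.is_tracked(2));
        assert_eq!(partial.open(2).unwrap(), Some(proof));
    }
}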

// CONVERSIONS
// ================================================================================================

impl From<MmrPeaks> for PartialMmr {
    fn from(peaks: MmrPeaks) -> Self {
        Self::from_peaks(peaks)
    }
}

impl From<PartialMmr> for MmrPeaks {
    fn from(partial_mmr: PartialMmr) -> Self {
        // Safety: the [PartialMmr] maintains the constraint that the number of true bits in the
        // forest matches the number of peaks, as required by the [MmrPeaks]
        MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
    }
}

impl From<&MmrPeaks> for PartialMmr {
    fn from(peaks: &MmrPeaks) -> Self {
        Self::from_peaks(peaks.clone())
    }
}

impl From<&PartialMmr> for MmrPeaks {
    fn from(partial_mmr: &PartialMmr) -> Self {
        // Safety: the [PartialMmr] maintains the constraint that the number of true bits in the
        // forest matches the number of peaks, as required by the [MmrPeaks]
        MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
    }
}

// ITERATORS
// ================================================================================================

/// An iterator over every inner node of the [PartialMmr].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
    nodes: &'a NodeMap,
    leaves: I,
    stack: Vec<(InOrderIndex, RpoDigest)>,
    seen_nodes: BTreeSet<InOrderIndex>,
}

impl<I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'_, I> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some((idx, node)) = self.stack.pop() {
            let parent_idx = idx.parent();
            let new_node = self.seen_nodes.insert(parent_idx);

            // if we haven't seen this node's parent before, and the node has a sibling, return
            // the inner node defined by the parent of this node, and move up the branch
            if new_node {
                if let Some(sibling) = self.nodes.get(&idx.sibling()) {
                    let (left, right) = if parent_idx.left_child() == idx {
                        (node, *sibling)
                    } else {
                        (*sibling, node)
                    };
                    let parent = Rpo256::merge(&[left, right]);
                    let inner_node = InnerNodeInfo { value: parent, left, right };

                    self.stack.push((parent_idx, parent));
                    return Some(inner_node);
                }
            }

            // the previous leaf has been processed, try to process the next leaf
            if let Some((pos, leaf)) = self.leaves.next() {
                let idx = InOrderIndex::from_leaf_pos(pos);
                self.stack.push((idx, leaf));
            }
        }

        None
    }
}

impl Serializable for PartialMmr {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.forest.write_into(target);
        self.peaks.write_into(target);
        self.nodes.write_into(target);
        target.write_bool(self.track_latest);
    }
}

impl Deserializable for PartialMmr {
    fn read_from<R: winter_utils::ByteReader>(
        source: &mut R,
    ) -> Result<Self, winter_utils::DeserializationError> {
        let forest = usize::read_from(source)?;
        let peaks = Vec::<RpoDigest>::read_from(source)?;
        let nodes = NodeMap::read_from(source)?;
        let track_latest = source.read_bool()?;

        Ok(Self { forest, peaks, nodes, track_latest })
    }
}

// UTILS
// ================================================================================================

/// Given the description of a `forest`, returns the index of the root element of the smallest
/// tree in it.
fn forest_to_root_index(forest: usize) -> InOrderIndex {
    // Count total size of all trees in the forest.
    let nodes = nodes_in_forest(forest);

    // Add the count for the parent nodes that separate each tree. These are allocated but
    // currently empty, and correspond to the nodes that will be used once the trees are merged.
    let open_trees = (forest.count_ones() - 1) as usize;

    // Remove the count of the right subtree of the target tree; the target tree's root index
    // comes before its right subtree in the in-order tree walk.
    let right_subtree_count = ((1u32 << forest.trailing_zeros()) - 1) as usize;

    let idx = nodes + open_trees - right_subtree_count;

    InOrderIndex::new(idx.try_into().unwrap())
}

/// Given the description of a `forest`, returns the index of the rightmost element.
fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
    // Count total size of all trees in the forest.
    let nodes = nodes_in_forest(forest);

    // Add the count for the parent nodes that separate each tree. These are allocated but
    // currently empty, and correspond to the nodes that will be used once the trees are merged.
    let open_trees = (forest.count_ones() - 1) as usize;

    let idx = nodes + open_trees;

    InOrderIndex::new(idx.try_into().unwrap())
}

// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use alloc::{collections::BTreeSet, vec::Vec};
|
||||
|
||||
use winter_utils::{Deserializable, Serializable};
|
||||
|
||||
use super::{
|
||||
forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
|
||||
RpoDigest,
|
||||
};
|
||||
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
|
||||
|
||||
const LEAVES: [RpoDigest; 7] = [
|
||||
int_to_node(0),
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_root_index() {
|
||||
fn idx(pos: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(pos.try_into().unwrap())
|
||||
}
|
||||
|
||||
// When there is a single tree in the forest, the index is equivalent to the number of
|
||||
// leaves in that tree, which is `2^n`.
|
||||
assert_eq!(forest_to_root_index(0b0001), idx(1));
|
||||
assert_eq!(forest_to_root_index(0b0010), idx(2));
|
||||
assert_eq!(forest_to_root_index(0b0100), idx(4));
|
||||
assert_eq!(forest_to_root_index(0b1000), idx(8));
|
||||
|
||||
assert_eq!(forest_to_root_index(0b0011), idx(5));
|
||||
assert_eq!(forest_to_root_index(0b0101), idx(9));
|
||||
assert_eq!(forest_to_root_index(0b1001), idx(17));
|
||||
assert_eq!(forest_to_root_index(0b0111), idx(13));
|
||||
assert_eq!(forest_to_root_index(0b1011), idx(21));
|
||||
assert_eq!(forest_to_root_index(0b1111), idx(29));
|
||||
|
||||
assert_eq!(forest_to_root_index(0b0110), idx(10));
|
||||
assert_eq!(forest_to_root_index(0b1010), idx(18));
|
||||
assert_eq!(forest_to_root_index(0b1100), idx(20));
|
||||
assert_eq!(forest_to_root_index(0b1110), idx(26));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forest_to_rightmost_index() {
|
||||
fn idx(pos: usize) -> InOrderIndex {
|
||||
InOrderIndex::new(pos.try_into().unwrap())
|
||||
}
|
||||
|
||||
for forest in 1..256 {
|
||||
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
|
||||
}
|
||||
|
||||
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
|
||||
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
|
||||
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
|
||||
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
|
||||
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
|
||||
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
|
||||
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
|
||||
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
|
||||
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
|
||||
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
|
||||
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
|
||||
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
|
||||
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
|
||||
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
|
||||
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_apply_delta() {
|
||||
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
|
||||
let mut mmr = Mmr::default();
|
||||
(0..10).for_each(|i| mmr.add(int_to_node(i)));
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks().into();
|
||||
|
||||
// add authentication path for position 1 and 8
|
||||
{
|
||||
let node = mmr.get(1).unwrap();
|
||||
let proof = mmr.open(1).unwrap();
|
||||
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
let node = mmr.get(8).unwrap();
|
||||
let proof = mmr.open(8).unwrap();
|
||||
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
|
||||
}
|
||||
|
||||
// add 2 more nodes into the MMR and validate apply_delta()
|
||||
(10..12).for_each(|i| mmr.add(int_to_node(i)));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
|
||||
// add 1 more node to the MMR, validate apply_delta() and start tracking the node
|
||||
mmr.add(int_to_node(12));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
{
|
||||
let node = mmr.get(12).unwrap();
|
||||
let proof = mmr.open(12).unwrap();
|
||||
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
|
||||
assert!(partial_mmr.track_latest);
|
||||
}
|
||||
|
||||
// by this point we are tracking authentication paths for positions: 1, 8, and 12
|
||||
|
||||
// add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
|
||||
(13..16).for_each(|i| mmr.add(int_to_node(i)));
|
||||
validate_apply_delta(&mmr, &mut partial_mmr);
|
||||
}
|
||||
|
||||
fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
|
||||
let tracked_leaves = partial
|
||||
.nodes
|
||||
.iter()
|
||||
.filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
|
||||
.collect::<Vec<_>>();
|
||||
let nodes_before = partial.nodes.clone();
|
||||
|
||||
// compute and apply delta
|
||||
let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
|
||||
let nodes_delta = partial.apply(delta).unwrap();
|
||||
|
||||
// new peaks were computed correctly
|
||||
assert_eq!(mmr.peaks(), partial.peaks());
|
||||
|
||||
let mut expected_nodes = nodes_before;
|
||||
for (key, value) in nodes_delta {
|
||||
// nodes should not be duplicated
|
||||
assert!(expected_nodes.insert(key, value).is_none());
|
||||
}
|
||||
|
||||
// new nodes should be a combination of original nodes and delta
|
||||
assert_eq!(expected_nodes, partial.nodes);
|
||||
|
||||
// make sure tracked leaves open to the same proofs as in the underlying MMR
|
||||
for index in tracked_leaves {
|
||||
let index_value: u64 = index.into();
|
||||
let pos = index_value / 2;
|
||||
let proof1 = partial.open(pos as usize).unwrap().unwrap();
|
||||
let proof2 = mmr.open(pos as usize).unwrap();
|
||||
assert_eq!(proof1, proof2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_inner_nodes_iterator() {
|
||||
// build the MMR
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let first_peak = mmr.peaks().peaks()[0];
|
||||
|
||||
// -- test single tree ----------------------------
|
||||
|
||||
// get path and node for position 1
|
||||
let node1 = mmr.get(1).unwrap();
|
||||
let proof1 = mmr.open(1).unwrap();
|
||||
|
||||
// create partial MMR and add authentication path to node at position 1
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks().into();
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
|
||||
// empty iterator should have no nodes
|
||||
assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
|
||||
// -- test no duplicates --------------------------
|
||||
|
||||
// build the partial MMR
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks().into();
|
||||
|
||||
let node0 = mmr.get(0).unwrap();
|
||||
let proof0 = mmr.open(0).unwrap();
|
||||
|
||||
let node2 = mmr.get(2).unwrap();
|
||||
let proof2 = mmr.open(2).unwrap();
|
||||
|
||||
partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();
|
||||
|
||||
// make sure there are no duplicates
|
||||
let leaves = [(0, node0), (1, node1), (2, node2)];
|
||||
let mut nodes = BTreeSet::new();
|
||||
for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
|
||||
assert!(nodes.insert(node.value));
|
||||
}
|
||||
|
||||
// and also that the store is still be built correctly
|
||||
store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));
|
||||
|
||||
let index0 = NodeIndex::new(2, 0).unwrap();
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let index2 = NodeIndex::new(2, 2).unwrap();
|
||||
|
||||
let path0 = store.get_path(first_peak, index0).unwrap().path;
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
let path2 = store.get_path(first_peak, index2).unwrap().path;
|
||||
|
||||
assert_eq!(path0, proof0.merkle_path);
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
assert_eq!(path2, proof2.merkle_path);
|
||||
|
||||
// -- test multiple trees -------------------------
|
||||
|
||||
// build the partial MMR
|
||||
let mut partial_mmr: PartialMmr = mmr.peaks().into();
|
||||
|
||||
let node5 = mmr.get(5).unwrap();
|
||||
let proof5 = mmr.open(5).unwrap();
|
||||
|
||||
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
|
||||
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
|
||||
|
||||
// build Merkle store from authentication paths in partial MMR
|
||||
let mut store: MerkleStore = MerkleStore::new();
|
||||
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));
|
||||
|
||||
let index1 = NodeIndex::new(2, 1).unwrap();
|
||||
let index5 = NodeIndex::new(1, 1).unwrap();
|
||||
|
||||
let second_peak = mmr.peaks().peaks()[1];
|
||||
|
||||
let path1 = store.get_path(first_peak, index1).unwrap().path;
|
||||
let path5 = store.get_path(second_peak, index5).unwrap().path;
|
||||
|
||||
assert_eq!(path1, proof1.merkle_path);
|
||||
assert_eq!(path5, proof5.merkle_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_without_track() {
|
||||
let mut mmr = Mmr::default();
|
||||
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
|
||||
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
|
||||
|
||||
for el in (0..256).map(int_to_node) {
|
||||
mmr.add(el);
|
||||
partial_mmr.add(el, false);
|
||||
|
||||
assert_eq!(mmr.peaks(), partial_mmr.peaks());
|
||||
assert_eq!(mmr.forest(), partial_mmr.forest());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_with_track() {
|
||||
let mut mmr = Mmr::default();
|
||||
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
|
||||
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
|
||||
|
||||
for i in 0..256 {
|
||||
let el = int_to_node(i);
|
||||
mmr.add(el);
|
||||
partial_mmr.add(el, true);
|
||||
|
||||
assert_eq!(mmr.peaks(), partial_mmr.peaks());
|
||||
assert_eq!(mmr.forest(), partial_mmr.forest());
|
||||
|
||||
for pos in 0..i {
|
||||
let mmr_proof = mmr.open(pos as usize).unwrap();
|
||||
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
|
||||
assert_eq!(mmr_proof, partialmmr_proof);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_add_existing_track() {
|
||||
let mut mmr = Mmr::from((0..7).map(int_to_node));
|
||||
|
||||
// derive a partial Mmr from it which tracks the authentication path to leaf 5
|
||||
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
|
||||
let path_to_5 = mmr.open(5).unwrap().merkle_path;
|
||||
let leaf_at_5 = mmr.get(5).unwrap();
|
||||
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
|
||||
|
||||
// add a new leaf to both Mmr and partial Mmr
|
||||
let leaf_at_7 = int_to_node(7);
|
||||
mmr.add(leaf_at_7);
|
||||
partial_mmr.add(leaf_at_7, false);
|
||||
|
||||
// the openings should be the same
|
||||
assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_serialization() {
|
||||
let mmr = Mmr::from((0..7).map(int_to_node));
|
||||
let partial_mmr = PartialMmr::from_peaks(mmr.peaks());
|
||||
|
||||
let bytes = partial_mmr.to_bytes();
|
||||
let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();
|
||||
|
||||
assert_eq!(partial_mmr, decoded);
|
||||
}
|
||||
}
|
162
src/merkle/mmr/peaks.rs
Normal file
|
@ -0,0 +1,162 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
|
||||
|
||||
// MMR PEAKS
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MmrPeaks {
|
||||
/// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
|
||||
/// happens because the number of peaks goes up and down as the structure is used, causing
|
||||
/// existing trees to be merged and new ones to be created. As an example, every time the MMR
|
||||
/// has a power-of-two number of leaves there is a single peak.
|
||||
///
|
||||
/// Every tree in the MMR forest has a distinct power-of-two size; this means only the right-
|
||||
/// most tree can have an odd number of elements (i.e. `1`). Additionally, this means that the
|
||||
/// bits in `num_leaves` conveniently encode the size of each individual tree.
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
/// - With 5 leaves, the binary representation is `0b101`. The number of set bits equals the
|
||||
///   number of peaks, in this case 2. The 0-indexed position of each set bit determines the
|
||||
///   number of leaves in the corresponding tree, so the rightmost tree has `2**0` leaves
|
||||
///   and the leftmost has `2**2`.
|
||||
/// - With 12 leaves, the binary is `0b1100`. This case also has 2 peaks: the leftmost tree has
|
||||
///   `2**3=8` leaves, and the rightmost has `2**2=4`.
|
||||
num_leaves: usize,
|
||||
|
||||
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
|
||||
/// leaves, starting from the peak with the most children to the one with the fewest.
|
||||
///
|
||||
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
|
||||
peaks: Vec<RpoDigest>,
|
||||
}
|
||||
|
||||
impl MmrPeaks {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
|
||||
/// leaves in the underlying MMR.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the number of leaves and the number of peaks are inconsistent.
|
||||
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
|
||||
if num_leaves.count_ones() as usize != peaks.len() {
|
||||
return Err(MmrError::InvalidPeaks(format!(
|
||||
"number of one bits in leaves is {} which does not equal peak length {}",
|
||||
num_leaves.count_ones(),
|
||||
peaks.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self { num_leaves, peaks })
|
||||
}
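    // Illustrative sketch (not part of the API; `d0` and `d1` stand for
    // arbitrary `RpoDigest` values): with 5 leaves (binary 0b101) the forest
    // has two trees, so exactly two peaks must be provided:
    //
    //     MmrPeaks::new(5, vec![d0, d1])  // Ok: 0b101 has two set bits
    //     MmrPeaks::new(4, vec![d0, d1])  // Err: 0b100 has one set bit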
|
||||
|
||||
// ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a count of leaves in the underlying MMR.
|
||||
pub fn num_leaves(&self) -> usize {
|
||||
self.num_leaves
|
||||
}
|
||||
|
||||
/// Returns the number of peaks of the underlying MMR.
|
||||
pub fn num_peaks(&self) -> usize {
|
||||
self.peaks.len()
|
||||
}
|
||||
|
||||
/// Returns the list of peaks of the underlying MMR.
|
||||
pub fn peaks(&self) -> &[RpoDigest] {
|
||||
&self.peaks
|
||||
}
|
||||
|
||||
/// Returns the peak by the provided index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided peak index is greater or equal to the current number of
|
||||
/// peaks in the Mmr.
|
||||
pub fn get_peak(&self, peak_idx: usize) -> Result<&RpoDigest, MmrError> {
|
||||
self.peaks
|
||||
.get(peak_idx)
|
||||
.ok_or(MmrError::PeakOutOfBounds { peak_idx, peaks_len: self.peaks.len() })
|
||||
}
|
||||
|
||||
/// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
|
||||
/// the underlying MMR.
|
||||
pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
|
||||
(self.num_leaves, self.peaks)
|
||||
}
|
||||
|
||||
/// Hashes the peaks.
|
||||
///
|
||||
/// The procedure will:
|
||||
/// - Flatten and pad the peaks to a vector of Felts.
|
||||
/// - Hash the vector of Felts.
|
||||
pub fn hash_peaks(&self) -> RpoDigest {
|
||||
Rpo256::hash_elements(&self.flatten_and_pad_peaks())
|
||||
}
|
||||
|
||||
/// Verifies the Merkle opening proof.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - provided opening proof is invalid.
|
||||
/// - Mmr root value computed using the provided leaf value differs from the actual one.
|
||||
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> Result<(), MmrError> {
|
||||
let root = self.get_peak(opening.peak_index())?;
|
||||
opening
|
||||
.merkle_path
|
||||
.verify(opening.relative_pos() as u64, value, root)
|
||||
.map_err(MmrError::InvalidMerklePath)
|
||||
}
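    // Note: verification is local to a single tree of the forest. The peak at
    // `opening.peak_index()` serves as the root, and `opening.relative_pos()`
    // re-bases the global leaf position onto that tree.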
|
||||
|
||||
/// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
|
||||
///
|
||||
/// The procedure will:
|
||||
/// - Flatten the vector of Words into a vector of Felts.
|
||||
/// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO
|
||||
/// padding.
|
||||
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of hashing.
|
||||
pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
|
||||
let num_peaks = self.peaks.len();
|
||||
|
||||
// To achieve the padding rules above we calculate the length of the final vector.
|
||||
// This is calculated as the number of field elements. Each peak is 4 field elements.
|
||||
// The length is calculated as follows:
|
||||
// - If there are fewer than 16 peaks, the data is padded to 16 peaks and as such requires 64
|
||||
// field elements.
|
||||
// - If there are 16 or more peaks and the number of peaks is odd, the data is padded to
|
||||
// an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
|
||||
// - If there are 16 or more peaks and the number of peaks is even, the data is not padded
|
||||
// and as such requires `num_peaks * 4` field elements.
|
||||
let len = if num_peaks < 16 {
|
||||
64
|
||||
} else if num_peaks % 2 == 1 {
|
||||
(num_peaks + 1) * 4
|
||||
} else {
|
||||
num_peaks * 4
|
||||
};
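        // For illustration, the resulting lengths for a few peak counts:
        //   num_peaks =  3 -> len = 64  (padded up to 16 peaks)
        //   num_peaks = 17 -> len = 72  ((17 + 1) * 4, padded to an even count)
        //   num_peaks = 18 -> len = 72  (18 * 4, no padding needed)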
|
||||
|
||||
let mut elements = Vec::with_capacity(len);
|
||||
elements.extend_from_slice(
|
||||
&self
|
||||
.peaks
|
||||
.as_slice()
|
||||
.iter()
|
||||
.map(|digest| digest.into())
|
||||
.collect::<Vec<Word>>()
|
||||
.concat(),
|
||||
);
|
||||
elements.resize(len, ZERO);
|
||||
elements
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MmrPeaks> for Vec<RpoDigest> {
|
||||
fn from(peaks: MmrPeaks) -> Self {
|
||||
peaks.peaks
|
||||
}
|
||||
}
|
106
src/merkle/mmr/proof.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
// An MMR proof is built around a single Merkle path within one of the trees of the forest.
|
||||
use super::super::MerklePath;
|
||||
use super::{full::high_bitmask, leaf_to_corresponding_tree};
|
||||
|
||||
// MMR PROOF
|
||||
// ================================================================================================
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct MmrProof {
|
||||
/// The state of the MMR when the MmrProof was created.
|
||||
pub forest: usize,
|
||||
|
||||
/// The position of the leaf value in this MmrProof.
|
||||
pub position: usize,
|
||||
|
||||
/// The Merkle opening, starting from the value's sibling up to and excluding the root of the
|
||||
/// responsible tree.
|
||||
pub merkle_path: MerklePath,
|
||||
}
|
||||
|
||||
impl MmrProof {
|
||||
/// Converts the leaf's global position into a local position that can be used to verify the
|
||||
/// merkle_path.
|
||||
pub fn relative_pos(&self) -> usize {
|
||||
let tree_bit = leaf_to_corresponding_tree(self.position, self.forest)
|
||||
.expect("position must be part of the forest");
|
||||
let forest_before = self.forest & high_bitmask(tree_bit + 1);
|
||||
self.position - forest_before
|
||||
}
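    // Worked example (illustrative): with forest = 0b1011 (trees of 8, 2, and
    // 1 leaves) and position = 9, the leaf lives in the 2-leaf tree
    // (tree_bit = 1). high_bitmask(2) keeps only the larger trees, so
    // forest_before = 0b1000 = 8 and the local position is 9 - 8 = 1.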
|
||||
|
||||
/// Returns index of the MMR peak against which the Merkle path in this proof can be verified.
|
||||
pub fn peak_index(&self) -> usize {
|
||||
let root = leaf_to_corresponding_tree(self.position, self.forest)
|
||||
.expect("position must be part of the forest");
|
||||
let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
|
||||
let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
|
||||
(self.forest.count_ones() - num_smaller_peaks - 1) as usize
|
||||
}
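    // Worked example (illustrative): with forest = 0b1011 and position = 9,
    // the leaf's tree bit is 1. One smaller peak exists (the 1-leaf tree) and
    // the forest has 3 peaks in total, so peak_index = 3 - 1 - 1 = 1, i.e.
    // the middle peak.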
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{MerklePath, MmrProof};
|
||||
|
||||
#[test]
|
||||
fn test_peak_index() {
|
||||
// --- single peak forest ---------------------------------------------
|
||||
let forest = 8;
|
||||
|
||||
// all 8 leaves belong to the single peak 0
|
||||
for position in 0..8 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// --- forest with non-consecutive peaks ------------------------------
|
||||
let forest = 11;
|
||||
|
||||
// the first 8 leaves belong to peak 0
|
||||
for position in 0..8 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// the next 2 leaves belong to peak 1
|
||||
for position in 8..10 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 1);
|
||||
}
|
||||
|
||||
// the last leaf is the peak 2
|
||||
let proof = make_dummy_proof(forest, 10);
|
||||
assert_eq!(proof.peak_index(), 2);
|
||||
|
||||
// --- forest with consecutive peaks ----------------------------------
|
||||
let forest = 7;
|
||||
|
||||
// the first 4 leaves belong to peak 0
|
||||
for position in 0..4 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 0);
|
||||
}
|
||||
|
||||
// the next 2 leaves belong to peak 1
|
||||
for position in 4..6 {
|
||||
let proof = make_dummy_proof(forest, position);
|
||||
assert_eq!(proof.peak_index(), 1);
|
||||
}
|
||||
|
||||
// the last leaf is the peak 2
|
||||
let proof = make_dummy_proof(forest, 6);
|
||||
assert_eq!(proof.peak_index(), 2);
|
||||
}
|
||||
|
||||
fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
|
||||
MmrProof {
|
||||
forest,
|
||||
position,
|
||||
merkle_path: MerklePath::default(),
|
||||
}
|
||||
}
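    // An additional illustrative check, added as a sketch (it mirrors the
    // worked examples in the impl above and is not part of the original
    // suite): both position conversions on a three-tree forest.
    #[test]
    fn test_proof_position_conversions_sketch() {
        // forest 0b1011 has trees of 8, 2, and 1 leaves; global position 9 is
        // the second leaf of the 2-leaf tree
        let proof = make_dummy_proof(0b1011, 9);
        assert_eq!(proof.relative_pos(), 1);
        assert_eq!(proof.peak_index(), 1);
    }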
|
||||
}
|
890
src/merkle/mmr/tests.rs
Normal file
|
@ -0,0 +1,890 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use super::{
|
||||
super::{InnerNodeInfo, Rpo256, RpoDigest},
|
||||
bit::TrueBitPositionIterator,
|
||||
full::high_bitmask,
|
||||
leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
|
||||
};
|
||||
use crate::{
|
||||
merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
|
||||
Felt, Word,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_position_equal_or_higher_than_leafs_is_never_contained() {
|
||||
let empty_forest = 0;
|
||||
for pos in 1..1024 {
|
||||
// pos is index, 0 based
|
||||
// tree is a length counter, 1 based
|
||||
// so a valid pos is always smaller, not equal, to tree
|
||||
assert_eq!(leaf_to_corresponding_tree(pos, pos), None);
|
||||
assert_eq!(leaf_to_corresponding_tree(pos, pos - 1), None);
|
||||
// and empty forest has no trees, so no position is valid
|
||||
assert_eq!(leaf_to_corresponding_tree(pos, empty_forest), None);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_position_zero_is_always_contained_by_the_highest_tree() {
|
||||
for leaves in 1..1024usize {
|
||||
let tree = leaves.ilog2();
|
||||
assert_eq!(leaf_to_corresponding_tree(0, leaves), Some(tree));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_leaf_to_corresponding_tree() {
|
||||
assert_eq!(leaf_to_corresponding_tree(0, 0b0001), Some(0));
|
||||
assert_eq!(leaf_to_corresponding_tree(0, 0b0010), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(0, 0b0011), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(0, 0b1011), Some(3));
|
||||
|
||||
// position one is always owned by the left-most tree
|
||||
assert_eq!(leaf_to_corresponding_tree(1, 0b0010), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(1, 0b0011), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(1, 0b1011), Some(3));
|
||||
|
||||
// position two starts as its own root, and then it is merged with the left-most tree
|
||||
assert_eq!(leaf_to_corresponding_tree(2, 0b0011), Some(0));
|
||||
assert_eq!(leaf_to_corresponding_tree(2, 0b0100), Some(2));
|
||||
assert_eq!(leaf_to_corresponding_tree(2, 0b1011), Some(3));
|
||||
|
||||
// position three is merged into the left-most tree
|
||||
assert_eq!(leaf_to_corresponding_tree(3, 0b0011), None);
|
||||
assert_eq!(leaf_to_corresponding_tree(3, 0b0100), Some(2));
|
||||
assert_eq!(leaf_to_corresponding_tree(3, 0b1011), Some(3));
|
||||
|
||||
assert_eq!(leaf_to_corresponding_tree(4, 0b0101), Some(0));
|
||||
assert_eq!(leaf_to_corresponding_tree(4, 0b0110), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(4, 0b0111), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(4, 0b1000), Some(3));
|
||||
|
||||
assert_eq!(leaf_to_corresponding_tree(12, 0b01101), Some(0));
|
||||
assert_eq!(leaf_to_corresponding_tree(12, 0b01110), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(12, 0b01111), Some(1));
|
||||
assert_eq!(leaf_to_corresponding_tree(12, 0b10000), Some(4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_high_bitmask() {
|
||||
assert_eq!(high_bitmask(0), usize::MAX);
|
||||
assert_eq!(high_bitmask(1), usize::MAX << 1);
|
||||
assert_eq!(high_bitmask(usize::BITS - 2), 0b11usize.rotate_right(2));
|
||||
assert_eq!(high_bitmask(usize::BITS - 1), 0b1usize.rotate_right(1));
|
||||
assert_eq!(high_bitmask(usize::BITS), 0, "overflow should be handled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nodes_in_forest() {
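    // each tree with 2^k leaves contributes 2^(k+1) - 1 nodes, so a forest
    // with `f` leaves has `2 * f - f.count_ones()` nodes in total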
|
||||
assert_eq!(nodes_in_forest(0b0000), 0);
|
||||
assert_eq!(nodes_in_forest(0b0001), 1);
|
||||
assert_eq!(nodes_in_forest(0b0010), 3);
|
||||
assert_eq!(nodes_in_forest(0b0011), 4);
|
||||
assert_eq!(nodes_in_forest(0b0100), 7);
|
||||
assert_eq!(nodes_in_forest(0b0101), 8);
|
||||
assert_eq!(nodes_in_forest(0b0110), 10);
|
||||
assert_eq!(nodes_in_forest(0b0111), 11);
|
||||
assert_eq!(nodes_in_forest(0b1000), 15);
|
||||
assert_eq!(nodes_in_forest(0b1001), 16);
|
||||
assert_eq!(nodes_in_forest(0b1010), 18);
|
||||
assert_eq!(nodes_in_forest(0b1011), 19);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nodes_in_forest_single_bit() {
|
||||
assert_eq!(nodes_in_forest(2usize.pow(0)), 2usize.pow(1) - 1);
|
||||
assert_eq!(nodes_in_forest(2usize.pow(1)), 2usize.pow(2) - 1);
|
||||
assert_eq!(nodes_in_forest(2usize.pow(2)), 2usize.pow(3) - 1);
|
||||
assert_eq!(nodes_in_forest(2usize.pow(3)), 2usize.pow(4) - 1);
|
||||
|
||||
for bit in 0..(usize::BITS - 1) {
|
||||
let size = 2usize.pow(bit + 1) - 1;
|
||||
assert_eq!(nodes_in_forest(1usize << bit), size);
|
||||
}
|
||||
}
|
||||
|
||||
const LEAVES: [RpoDigest; 7] = [
|
||||
int_to_node(0),
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
];
|
||||
|
||||
#[test]
|
||||
fn test_mmr_simple() {
|
||||
let mut postorder = vec![
|
||||
LEAVES[0],
|
||||
LEAVES[1],
|
||||
merge(LEAVES[0], LEAVES[1]),
|
||||
LEAVES[2],
|
||||
LEAVES[3],
|
||||
merge(LEAVES[2], LEAVES[3]),
|
||||
];
|
||||
postorder.push(merge(postorder[2], postorder[5]));
|
||||
postorder.push(LEAVES[4]);
|
||||
postorder.push(LEAVES[5]);
|
||||
postorder.push(merge(LEAVES[4], LEAVES[5]));
|
||||
postorder.push(LEAVES[6]);
|
||||
|
||||
let mut mmr = Mmr::new();
|
||||
assert_eq!(mmr.forest(), 0);
|
||||
assert_eq!(mmr.nodes.len(), 0);
|
||||
|
||||
mmr.add(LEAVES[0]);
|
||||
assert_eq!(mmr.forest(), 1);
|
||||
assert_eq!(mmr.nodes.len(), 1);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 1);
|
||||
assert_eq!(acc.peaks(), &[postorder[0]]);
|
||||
|
||||
mmr.add(LEAVES[1]);
|
||||
assert_eq!(mmr.forest(), 2);
|
||||
assert_eq!(mmr.nodes.len(), 3);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 2);
|
||||
assert_eq!(acc.peaks(), &[postorder[2]]);
|
||||
|
||||
mmr.add(LEAVES[2]);
|
||||
assert_eq!(mmr.forest(), 3);
|
||||
assert_eq!(mmr.nodes.len(), 4);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 3);
|
||||
assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
|
||||
|
||||
mmr.add(LEAVES[3]);
|
||||
assert_eq!(mmr.forest(), 4);
|
||||
assert_eq!(mmr.nodes.len(), 7);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 4);
|
||||
assert_eq!(acc.peaks(), &[postorder[6]]);
|
||||
|
||||
mmr.add(LEAVES[4]);
|
||||
assert_eq!(mmr.forest(), 5);
|
||||
assert_eq!(mmr.nodes.len(), 8);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 5);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
|
||||
|
||||
mmr.add(LEAVES[5]);
|
||||
assert_eq!(mmr.forest(), 6);
|
||||
assert_eq!(mmr.nodes.len(), 10);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 6);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
|
||||
|
||||
mmr.add(LEAVES[6]);
|
||||
assert_eq!(mmr.forest(), 7);
|
||||
assert_eq!(mmr.nodes.len(), 11);
|
||||
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
|
||||
|
||||
let acc = mmr.peaks();
|
||||
assert_eq!(acc.num_leaves(), 7);
|
||||
assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_open() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let h01 = merge(LEAVES[0], LEAVES[1]);
|
||||
let h23 = merge(LEAVES[2], LEAVES[3]);
|
||||
|
||||
// position 7 is outside of the forest
|
||||
assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, opening it should return an error");
|
||||
|
||||
// the node at pos 6 is a single-leaf peak, so its Merkle path is empty
|
||||
let empty: MerklePath = MerklePath::new(vec![]);
|
||||
let opening = mmr
|
||||
.open(6)
|
||||
.expect("Element 6 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, empty);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 6);
|
||||
mmr.peaks().verify(LEAVES[6], opening).unwrap();
|
||||
|
||||
// nodes 4,5 are depth 1
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[4]]);
|
||||
let opening = mmr
|
||||
.open(5)
|
||||
.expect("Element 5 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 5);
|
||||
mmr.peaks().verify(LEAVES[5], opening).unwrap();
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[5]]);
|
||||
let opening = mmr
|
||||
.open(4)
|
||||
.expect("Element 4 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 4);
|
||||
mmr.peaks().verify(LEAVES[4], opening).unwrap();
|
||||
|
||||
// nodes 0,1,2,3 are depth 2
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
|
||||
let opening = mmr
|
||||
.open(3)
|
||||
.expect("Element 3 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 3);
|
||||
mmr.peaks().verify(LEAVES[3], opening).unwrap();
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
|
||||
let opening = mmr
|
||||
.open(2)
|
||||
.expect("Element 2 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 2);
|
||||
mmr.peaks().verify(LEAVES[2], opening).unwrap();
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
|
||||
let opening = mmr
|
||||
.open(1)
|
||||
.expect("Element 1 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 1);
|
||||
mmr.peaks().verify(LEAVES[1], opening).unwrap();
|
||||
|
||||
let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
|
||||
let opening = mmr
|
||||
.open(0)
|
||||
.expect("Element 0 is contained in the tree, expected an opening result.");
|
||||
assert_eq!(opening.merkle_path, root_to_path);
|
||||
assert_eq!(opening.forest, mmr.forest);
|
||||
assert_eq!(opening.position, 0);
|
||||
mmr.peaks().verify(LEAVES[0], opening).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_open_older_version() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
fn is_even(v: &usize) -> bool {
|
||||
v & 1 == 0
|
||||
}
|
||||
|
||||
// merkle path of a node is empty if there are no elements to pair with it
|
||||
for pos in (0..mmr.forest()).filter(is_even) {
|
||||
let forest = pos + 1;
|
||||
let proof = mmr.open_at(pos, forest).unwrap();
|
||||
assert_eq!(proof.forest, forest);
|
||||
assert_eq!(proof.merkle_path.nodes(), []);
|
||||
assert_eq!(proof.position, pos);
|
||||
}
|
||||
|
||||
// openings match that of a merkle tree
|
||||
let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
|
||||
for forest in 4..=LEAVES.len() {
|
||||
for pos in 0..4 {
|
||||
let idx = NodeIndex::new(2, pos).unwrap();
|
||||
let path = mtree.get_path(idx).unwrap();
|
||||
let proof = mmr.open_at(pos as usize, forest).unwrap();
|
||||
assert_eq!(path, proof.merkle_path);
|
||||
}
|
||||
}
|
||||
let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
|
||||
for forest in 6..=LEAVES.len() {
|
||||
for pos in 0..2 {
|
||||
let idx = NodeIndex::new(1, pos).unwrap();
|
||||
let path = mtree.get_path(idx).unwrap();
|
||||
// account for the bigger tree with 4 elements
|
||||
let mmr_pos = (pos + 4) as usize;
|
||||
let proof = mmr.open_at(mmr_pos, forest).unwrap();
|
||||
assert_eq!(path, proof.merkle_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tests the openings of a simple Mmr with a single tree of 8 leaves (depth 3).
|
||||
#[test]
|
||||
fn test_mmr_open_eight() {
|
||||
let leaves = [
|
||||
int_to_node(0),
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
int_to_node(7),
|
||||
];
|
||||
|
||||
let mtree: MerkleTree = leaves.as_slice().try_into().unwrap();
|
||||
let forest = leaves.len();
|
||||
let mmr: Mmr = leaves.into();
|
||||
let root = mtree.root();
|
||||
|
||||
let position = 0;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 1;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 2;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 3;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 4;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 5;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 6;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
|
||||
let position = 7;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
|
||||
}
|
||||
|
||||
/// Tests the openings of an Mmr with 3 trees of 4, 2, and 1 leaves.
|
||||
#[test]
|
||||
fn test_mmr_open_seven() {
|
||||
let mtree1: MerkleTree = LEAVES[..4].try_into().unwrap();
|
||||
let mtree2: MerkleTree = LEAVES[4..6].try_into().unwrap();
|
||||
|
||||
let forest = LEAVES.len();
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
let position = 0;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 1;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 2;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 3;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath =
|
||||
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());
|
||||
|
||||
let position = 4;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());
|
||||
|
||||
let position = 5;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());
|
||||
|
||||
let position = 6;
|
||||
let proof = mmr.open(position).unwrap();
|
||||
let merkle_path: MerklePath = [].as_ref().into();
|
||||
assert_eq!(proof, MmrProof { forest, position, merkle_path });
|
||||
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_get() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
assert_eq!(mmr.get(0).unwrap(), LEAVES[0], "value at pos 0 must correspond");
|
||||
assert_eq!(mmr.get(1).unwrap(), LEAVES[1], "value at pos 1 must correspond");
|
||||
assert_eq!(mmr.get(2).unwrap(), LEAVES[2], "value at pos 2 must correspond");
|
||||
assert_eq!(mmr.get(3).unwrap(), LEAVES[3], "value at pos 3 must correspond");
|
||||
assert_eq!(mmr.get(4).unwrap(), LEAVES[4], "value at pos 4 must correspond");
|
||||
assert_eq!(mmr.get(5).unwrap(), LEAVES[5], "value at pos 5 must correspond");
|
||||
assert_eq!(mmr.get(6).unwrap(), LEAVES[6], "value at pos 6 must correspond");
|
||||
assert!(mmr.get(7).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_invariants() {
|
||||
let mut mmr = Mmr::new();
|
||||
for v in 1..=1028 {
|
||||
mmr.add(int_to_node(v));
|
||||
let accumulator = mmr.peaks();
|
||||
assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
|
||||
assert_eq!(
|
||||
v as usize,
|
||||
accumulator.num_leaves(),
|
||||
"MMR and its accumulator must match leaves count"
|
||||
);
|
||||
assert_eq!(
|
||||
accumulator.num_leaves().count_ones() as usize,
|
||||
accumulator.peaks().len(),
|
||||
"bits on leaves must match the number of peaks"
|
||||
);
|
||||
|
||||
let expected_nodes: usize = TrueBitPositionIterator::new(mmr.forest())
|
||||
.map(|bit_pos| nodes_in_forest(1 << bit_pos))
|
||||
.sum();
|
||||
|
||||
assert_eq!(
|
||||
expected_nodes,
|
||||
mmr.nodes.len(),
|
||||
"the sum of every tree size must be equal to the number of nodes in the MMR (forest: {:b})",
|
||||
mmr.forest(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bit_position_iterator() {
|
||||
assert_eq!(TrueBitPositionIterator::new(0).count(), 0);
|
||||
assert_eq!(TrueBitPositionIterator::new(0).rev().count(), 0);
|
||||
|
||||
assert_eq!(TrueBitPositionIterator::new(1).collect::<Vec<u32>>(), vec![0]);
|
||||
assert_eq!(TrueBitPositionIterator::new(1).rev().collect::<Vec<u32>>(), vec![0],);
|
||||
|
||||
assert_eq!(TrueBitPositionIterator::new(2).collect::<Vec<u32>>(), vec![1]);
|
||||
assert_eq!(TrueBitPositionIterator::new(2).rev().collect::<Vec<u32>>(), vec![1],);
|
||||
|
||||
assert_eq!(TrueBitPositionIterator::new(3).collect::<Vec<u32>>(), vec![0, 1],);
|
||||
assert_eq!(TrueBitPositionIterator::new(3).rev().collect::<Vec<u32>>(), vec![1, 0],);
|
||||
|
||||
assert_eq!(
|
||||
TrueBitPositionIterator::new(0b11010101).collect::<Vec<u32>>(),
|
||||
vec![0, 2, 4, 6, 7],
|
||||
);
|
||||
assert_eq!(
|
||||
TrueBitPositionIterator::new(0b11010101).rev().collect::<Vec<u32>>(),
|
||||
vec![7, 6, 4, 2, 0],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_inner_nodes() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let nodes: Vec<InnerNodeInfo> = mmr.inner_nodes().collect();
|
||||
|
||||
let h01 = Rpo256::merge(&[LEAVES[0], LEAVES[1]]);
|
||||
let h23 = Rpo256::merge(&[LEAVES[2], LEAVES[3]]);
|
||||
let h0123 = Rpo256::merge(&[h01, h23]);
|
||||
let h45 = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
|
||||
let postorder = vec![
|
||||
InnerNodeInfo {
|
||||
value: h01,
|
||||
left: LEAVES[0],
|
||||
right: LEAVES[1],
|
||||
},
|
||||
InnerNodeInfo {
|
||||
value: h23,
|
||||
left: LEAVES[2],
|
||||
right: LEAVES[3],
|
||||
},
|
||||
InnerNodeInfo { value: h0123, left: h01, right: h23 },
|
||||
InnerNodeInfo {
|
||||
value: h45,
|
||||
left: LEAVES[4],
|
||||
right: LEAVES[5],
|
||||
},
|
||||
];
|
||||
|
||||
assert_eq!(postorder, nodes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_peaks() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
let forest = 0b0001;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[0]]);
|
||||
|
||||
let forest = 0b0010;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[2]]);
|
||||
|
||||
let forest = 0b0011;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);
|
||||
|
||||
let forest = 0b0100;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6]]);
|
||||
|
||||
let forest = 0b0101;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);
|
||||
|
||||
let forest = 0b0110;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);
|
||||
|
||||
let forest = 0b0111;
|
||||
let acc = mmr.peaks_at(forest).unwrap();
|
||||
assert_eq!(acc.num_leaves(), forest);
|
||||
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_hash_peaks() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let peaks = mmr.peaks();
|
||||
|
||||
let first_peak = Rpo256::merge(&[
|
||||
Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
|
||||
Rpo256::merge(&[LEAVES[2], LEAVES[3]]),
|
||||
]);
|
||||
let second_peak = Rpo256::merge(&[LEAVES[4], LEAVES[5]]);
|
||||
let third_peak = LEAVES[6];
|
||||
|
||||
// minimum length is 16
|
||||
let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
|
||||
expected_peaks.resize(16, RpoDigest::default());
|
||||
assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_peaks_hash_less_than_16() {
|
||||
let mut peaks = Vec::new();
|
||||
|
||||
for i in 0..16 {
|
||||
peaks.push(int_to_node(i));
|
||||
|
||||
let num_leaves = (1 << peaks.len()) - 1;
|
||||
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
|
||||
|
||||
// minimum length is 16
|
||||
let mut expected_peaks = peaks.clone();
|
||||
expected_peaks.resize(16, RpoDigest::default());
|
||||
assert_eq!(
|
||||
accumulator.hash_peaks(),
|
||||
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_peaks_hash_odd() {
|
||||
let peaks: Vec<_> = (0..=17).map(int_to_node).collect();
|
||||
|
||||
let num_leaves = (1 << peaks.len()) - 1;
|
||||
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
|
||||
|
||||
// odd length bigger than 16 is padded to the next even number
|
||||
let mut expected_peaks = peaks;
|
||||
expected_peaks.resize(18, RpoDigest::default());
|
||||
assert_eq!(
|
||||
accumulator.hash_peaks(),
|
||||
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_delta() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let acc = mmr.peaks();
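    // get_delta(from_forest, to_forest) returns the data a client at version
    // `from_forest` needs in order to advance its accumulator to `to_forest`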
|
||||
|
||||
// the requested original forest can't have more leaves than the current one
|
||||
assert!(
|
||||
mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
|
||||
"Can not provide updates for a newer Mmr"
|
||||
);
|
||||
|
||||
// if the number of elements is the same there is no change
|
||||
assert!(
|
||||
mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
|
||||
"There are no updates for the same Mmr version"
|
||||
);
|
||||
|
||||
// missing the last element added, which is itself a tree peak
|
||||
assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");
|
||||
|
||||
// missing the sibling to complete the tree of depth 2, and the last element
|
||||
assert_eq!(
|
||||
mmr.get_delta(5, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[5], acc.peaks()[2]],
|
||||
"one sibling, one peak"
|
||||
);
|
||||
|
||||
// missing the whole last two trees, only send the peaks
|
||||
assert_eq!(
|
||||
mmr.get_delta(4, mmr.forest()).unwrap().data,
|
||||
vec![acc.peaks()[1], acc.peaks()[2]],
|
||||
"two peaks"
|
||||
);
|
||||
|
||||
// missing the sibling to complete the first tree, and the two last trees
|
||||
assert_eq!(
|
||||
mmr.get_delta(3, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
|
||||
// missing half of the first tree, only send the computed element (not the leaves), and the new
|
||||
// peaks
|
||||
assert_eq!(
|
||||
mmr.get_delta(2, mmr.forest()).unwrap().data,
|
||||
vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
mmr.get_delta(1, mmr.forest()).unwrap().data,
|
||||
vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
|
||||
"one sibling, two peaks"
|
||||
);
|
||||
|
||||
assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_delta_old_forest() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
|
||||
// from_forest must be smaller than or equal to to_forest
|
||||
for version in 1..=mmr.forest() {
|
||||
assert!(mmr.get_delta(version + 1, version).is_err());
|
||||
}
|
||||
|
||||
// when from_forest and to_forest are equal, there are no updates
|
||||
for version in 1..=mmr.forest() {
|
||||
let delta = mmr.get_delta(version, version).unwrap();
|
||||
assert!(delta.data.is_empty());
|
||||
assert_eq!(delta.forest, version);
|
||||
}
|
||||
|
||||
// test update which merges the odd peak to the right
|
||||
for count in 0..(mmr.forest() / 2) {
|
||||
// *2 because every iteration tests a pair
|
||||
// +1 because the Mmr is 1-indexed
|
||||
let from_forest = (count * 2) + 1;
|
||||
let to_forest = (count * 2) + 2;
|
||||
let delta = mmr.get_delta(from_forest, to_forest).unwrap();
|
||||
|
||||
// *2 because every iteration tests a pair
|
||||
// +1 because sibling is the odd element
|
||||
let sibling = (count * 2) + 1;
|
||||
assert_eq!(delta.data, [LEAVES[sibling]]);
|
||||
assert_eq!(delta.forest, to_forest);
|
||||
}
|
||||
|
||||
let version = 4;
|
||||
let delta = mmr.get_delta(1, version).unwrap();
|
||||
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
|
||||
assert_eq!(delta.forest, version);
|
||||
|
||||
let version = 5;
|
||||
let delta = mmr.get_delta(1, version).unwrap();
|
||||
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
|
||||
assert_eq!(delta.forest, version);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_simple() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let peaks = mmr.peaks();
|
||||
let mut partial: PartialMmr = peaks.clone().into();
|
||||
|
||||
// check initial state of the partial mmr
|
||||
assert_eq!(partial.peaks(), peaks);
|
||||
assert_eq!(partial.forest(), peaks.num_leaves());
|
||||
assert_eq!(partial.forest(), LEAVES.len());
|
||||
assert_eq!(partial.peaks().num_peaks(), 3);
|
||||
assert_eq!(partial.nodes.len(), 0);
|
||||
|
||||
// check state after tracking one element
|
||||
let proof1 = mmr.open(0).unwrap();
|
||||
let el1 = mmr.get(proof1.position).unwrap();
|
||||
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
|
||||
|
||||
// check the number of nodes increased by the number of nodes in the proof
|
||||
assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
|
||||
// check the values match
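    // each Merkle path entry is stored under the in-order index of the
    // sibling of the corresponding node on the walk from the leaf to its peak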
|
||||
let idx = InOrderIndex::from_leaf_pos(proof1.position);
|
||||
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[0]);
|
||||
let idx = idx.parent();
|
||||
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);
|
||||
|
||||
let proof2 = mmr.open(1).unwrap();
|
||||
let el2 = mmr.get(proof2.position).unwrap();
|
||||
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
|
||||
|
||||
// check the number of nodes increased by a single element (the one that is not shared)
|
||||
assert_eq!(partial.nodes.len(), 3);
|
||||
// check the values match
|
||||
let idx = InOrderIndex::from_leaf_pos(proof2.position);
|
||||
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[0]);
|
||||
let idx = idx.parent();
|
||||
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_partial_mmr_update_single() {
|
||||
let mut full = Mmr::new();
|
||||
let zero = int_to_node(0);
|
||||
full.add(zero);
|
||||
let mut partial: PartialMmr = full.peaks().into();
|
||||
|
||||
let proof = full.open(0).unwrap();
|
||||
partial.track(proof.position, zero, &proof.merkle_path).unwrap();
|
||||
|
||||
for i in 1..100 {
|
||||
let node = int_to_node(i);
|
||||
full.add(node);
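        // sync the partial MMR by applying the delta between its current
        // version and the full MMR's version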
|
||||
let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
|
||||
partial.apply(delta).unwrap();
|
||||
|
||||
assert_eq!(partial.forest(), full.forest());
|
||||
assert_eq!(partial.peaks(), full.peaks());
|
||||
|
||||
let proof1 = full.open(i as usize).unwrap();
|
||||
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
|
||||
let proof2 = partial.open(proof1.position).unwrap().unwrap();
|
||||
assert_eq!(proof1.merkle_path, proof2.merkle_path);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mmr_add_invalid_odd_leaf() {
|
||||
let mmr: Mmr = LEAVES.into();
|
||||
let acc = mmr.peaks();
|
||||
let mut partial: PartialMmr = acc.clone().into();
|
||||
|
||||
let empty = MerklePath::new(Vec::new());
|
||||
|
||||
// None of the other leaves should work
|
||||
for node in LEAVES.iter().cloned().rev().skip(1) {
|
||||
let result = partial.track(LEAVES.len() - 1, node, &empty);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
/// Tests that a proof whose peak index exceeds the peak count of the MMR returns an error.
|
||||
///
|
||||
/// Here we manipulate the proof to return a peak index of 1 while the MMR only has 1 peak (with
|
||||
/// index 0).
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mmr_proof_num_peaks_exceeds_current_num_peaks() {
|
||||
let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
|
||||
let mut proof = mmr.open(3).unwrap();
|
||||
proof.forest = 5;
|
||||
proof.position = 4;
|
||||
mmr.peaks().verify(LEAVES[3], proof).unwrap();
|
||||
}
|
||||
|
||||
/// Tests that a proof whose peak index exceeds the peak count of the MMR returns an error.
|
||||
///
|
||||
/// We create an MmrProof for a leaf whose peak index to verify against is 1.
|
||||
/// Then we add another leaf which results in an Mmr with just one peak due to trees
|
||||
/// being merged. If we try to use the old proof against the new Mmr, we should get an error.
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_mmr_old_proof_num_peaks_exceeds_current_num_peaks() {
|
||||
let leaves_len = 3;
|
||||
let mut mmr = Mmr::from(LEAVES[0..leaves_len].iter().cloned());
|
||||
|
||||
let leaf_idx = leaves_len - 1;
|
||||
let proof = mmr.open(leaf_idx).unwrap();
|
||||
assert!(mmr.peaks().verify(LEAVES[leaf_idx], proof.clone()).is_ok());
|
||||
|
||||
mmr.add(LEAVES[leaves_len]);
|
||||
mmr.peaks().verify(LEAVES[leaf_idx], proof).unwrap();
|
||||
}
|
||||
|
||||
mod property_tests {
|
||||
use proptest::prelude::*;
|
||||
|
||||
use super::leaf_to_corresponding_tree;
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_last_position_is_always_contained_in_the_last_tree(leaves in any::<usize>().prop_filter("cant have an empty tree", |v| *v != 0)) {
|
||||
let last_pos = leaves - 1;
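            // the last tree of the forest corresponds to the lowest set bit
            // of `leaves`, so the last position must fall into that tree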
|
||||
let lowest_bit = leaves.trailing_zeros();
|
||||
|
||||
assert_eq!(
|
||||
leaf_to_corresponding_tree(last_pos, leaves),
|
||||
Some(lowest_bit),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn test_contained_tree_is_always_power_of_two((leaves, pos) in any::<usize>().prop_flat_map(|v| (Just(v), 0..v))) {
|
||||
let tree_bit = leaf_to_corresponding_tree(pos, leaves).expect("pos is smaller than leaves, there should always be a corresponding tree");
|
||||
let mask = 1usize << tree_bit;
|
||||
|
||||
assert!(tree_bit < usize::BITS, "the result must be a bit in usize");
|
||||
assert!(mask & leaves != 0, "the result should be a tree in leaves");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
fn digests_to_elements(digests: &[RpoDigest]) -> Vec<Felt> {
|
||||
digests.iter().flat_map(Word::from).collect()
|
||||
}
|
||||
|
||||
// shorthand for the RPO hash, used to make test code more concise and easier to read
|
||||
fn merge(l: RpoDigest, r: RpoDigest) -> RpoDigest {
|
||||
Rpo256::merge(&[l, r])
|
||||
}
|
62
src/merkle/mod.rs
Normal file
|
@ -0,0 +1,62 @@
|
|||
//! Data structures related to Merkle trees based on RPO256 hash function.
|
||||
|
||||
use super::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
Felt, Word, EMPTY_WORD, ZERO,
|
||||
};
|
||||
|
||||
// REEXPORTS
|
||||
// ================================================================================================
|
||||
|
||||
mod empty_roots;
|
||||
pub use empty_roots::EmptySubtreeRoots;
|
||||
|
||||
mod index;
|
||||
pub use index::NodeIndex;
|
||||
|
||||
mod merkle_tree;
|
||||
pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
|
||||
|
||||
mod path;
|
||||
pub use path::{MerklePath, RootPath, ValuePath};
|
||||
|
||||
mod smt;
|
||||
#[cfg(feature = "internal")]
|
||||
pub use smt::{build_subtree_for_bench, SubtreeLeaf};
|
||||
pub use smt::{
|
||||
InnerNode, LeafIndex, MutationSet, NodeMutation, PartialSmt, SimpleSmt, Smt, SmtLeaf,
|
||||
SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
|
||||
};
|
||||
|
||||
mod mmr;
|
||||
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
|
||||
|
||||
mod store;
|
||||
pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
|
||||
|
||||
mod node;
|
||||
pub use node::InnerNodeInfo;
|
||||
|
||||
mod partial_mt;
|
||||
pub use partial_mt::PartialMerkleTree;
|
||||
|
||||
mod error;
|
||||
pub use error::MerkleError;
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
const fn int_to_node(value: u64) -> RpoDigest {
|
||||
RpoDigest::new([Felt::new(value), ZERO, ZERO, ZERO])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
const fn int_to_leaf(value: u64) -> Word {
|
||||
[Felt::new(value), ZERO, ZERO, ZERO]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn digests_to_words(digests: &[RpoDigest]) -> alloc::vec::Vec<Word> {
|
||||
digests.iter().map(|d| d.into()).collect()
|
||||
}
|
11
src/merkle/node.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
use super::RpoDigest;
|
||||
|
||||
/// Representation of a node with two children, used when iterating over containers.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
#[cfg_attr(test, derive(PartialOrd, Ord))]
|
||||
pub struct InnerNodeInfo {
|
||||
pub value: RpoDigest,
|
||||
pub left: RpoDigest,
|
||||
pub right: RpoDigest,
|
||||
}
|
478
src/merkle/partial_mt/mod.rs
Normal file
|
@ -0,0 +1,478 @@
|
|||
use alloc::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
string::String,
|
||||
vec::Vec,
|
||||
};
|
||||
use core::fmt;
|
||||
|
||||
use super::{
|
||||
InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Word,
|
||||
EMPTY_WORD,
|
||||
};
|
||||
use crate::utils::{
|
||||
word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
/// Index of the root node.
|
||||
const ROOT_INDEX: NodeIndex = NodeIndex::root();
|
||||
|
||||
/// An RpoDigest consisting of 4 ZERO elements.
|
||||
const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD);
|
||||
|
||||
// PARTIAL MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. A partial Merkle
|
||||
/// tree allows a Merkle tree to be constructed from Merkle paths of different lengths.
|
||||
///
|
||||
/// The root of the tree is recomputed on each new leaf update.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct PartialMerkleTree {
|
||||
max_depth: u8,
|
||||
nodes: BTreeMap<NodeIndex, RpoDigest>,
|
||||
leaves: BTreeSet<NodeIndex>,
|
||||
}
|
||||
|
||||
impl Default for PartialMerkleTree {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialMerkleTree {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Minimum supported depth.
|
||||
pub const MIN_DEPTH: u8 = 1;
|
||||
|
||||
/// Maximum supported depth.
|
||||
pub const MAX_DEPTH: u8 = 64;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new empty [PartialMerkleTree].
|
||||
pub fn new() -> Self {
|
||||
PartialMerkleTree {
|
||||
max_depth: 0,
|
||||
nodes: BTreeMap::new(),
|
||||
leaves: BTreeSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new [PartialMerkleTree] built by adding every path from the provided iterator.
|
||||
///
|
||||
/// Analogous to [Self::add_path].
|
||||
pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
|
||||
where
|
||||
I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
|
||||
{
|
||||
// create an empty tree
|
||||
let tree = PartialMerkleTree::new();
|
||||
|
||||
paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
|
||||
tree.add_path(index, value, path)?;
|
||||
Ok(tree)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided
|
||||
/// entries.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The depth is 0 or is greater than 64.
|
||||
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
|
||||
/// - The provided entries contain an insufficient set of nodes.
|
||||
pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
|
||||
where
|
||||
R: IntoIterator<IntoIter = I>,
|
||||
I: Iterator<Item = (NodeIndex, RpoDigest)> + ExactSizeIterator,
|
||||
{
|
||||
let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
|
||||
let mut leaves = BTreeSet::new();
|
||||
let mut nodes = BTreeMap::new();
|
||||
|
||||
// add data to the leaves and nodes maps and also fill layers map, where the key is the
|
||||
// depth of the node and value is its index.
|
||||
for (node_index, hash) in entries.into_iter() {
|
||||
leaves.insert(node_index);
|
||||
nodes.insert(node_index, hash);
|
||||
layers
|
||||
.entry(node_index.depth())
|
||||
.and_modify(|layer_vec| layer_vec.push(node_index.value()))
|
||||
.or_insert(vec![node_index.value()]);
|
||||
}
|
||||
|
||||
// check if the number of leaves can be accommodated by the tree's depth; we cap the
|
||||
// capacity at 2^63 entries because we consider passing in a vector of size 2^64 infeasible.
|
||||
let max = 2usize.pow(63);
|
||||
if leaves.len() > max {
|
||||
return Err(MerkleError::TooManyEntries(max));
|
||||
}
|
||||
|
||||
// Get maximum depth
|
||||
let max_depth = *layers.keys().next_back().unwrap_or(&0);
|
||||
|
||||
// fill layers without nodes with empty vector
|
||||
for depth in 0..max_depth {
|
||||
layers.entry(depth).or_default();
|
||||
}
|
||||
|
||||
let mut layer_iter = layers.into_values().rev();
|
||||
let mut parent_layer = layer_iter.next().unwrap();
|
||||
let mut current_layer;
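        // walk the layers bottom-up: for every node on the current layer,
        // hash it with its sibling and record the parent, so that by the time
        // the loop reaches depth 1 all ancestors up to the root are computed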
|
||||
|
||||
for depth in (1..max_depth + 1).rev() {
|
||||
// set current_layer = parent_layer and parent_layer = layer_iter.next()
|
||||
current_layer = layer_iter.next().unwrap();
|
||||
core::mem::swap(&mut current_layer, &mut parent_layer);
|
||||
|
||||
for index_value in current_layer {
|
||||
// get the parent node index
|
||||
let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
|
||||
|
||||
// Check if the parent hash was already calculated. In about half of the cases, we
|
||||
// don't need to do anything.
|
||||
if !parent_layer.contains(&parent_node.value()) {
|
||||
// create current node index
|
||||
let index = NodeIndex::new(depth, index_value)?;
|
||||
|
||||
// get hash of the current node
|
||||
let node =
|
||||
nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
|
||||
// get hash of the sibling node
|
||||
let sibling = nodes
|
||||
.get(&index.sibling())
|
||||
.ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
|
||||
// get parent hash
|
||||
let parent = Rpo256::merge(&index.build_node(*node, *sibling));
|
||||
|
||||
// add index value of the calculated node to the parents layer
|
||||
parent_layer.push(parent_node.value());
|
||||
// add index and hash to the nodes map
|
||||
nodes.insert(parent_node, parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PartialMerkleTree { max_depth, nodes, leaves })
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root of this Merkle tree.
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
|
||||
}
|
||||
|
||||
/// Returns the depth of this Merkle tree.
|
||||
pub fn max_depth(&self) -> u8 {
|
||||
self.max_depth
|
||||
}
|
||||
|
||||
/// Returns a node at the specified NodeIndex.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
self.nodes
|
||||
.get(&index)
|
||||
.ok_or(MerkleError::NodeIndexNotFoundInTree(index))
|
||||
.copied()
|
||||
}
|
||||
|
||||
/// Returns true if provided index contains in the leaves set, false otherwise.
|
||||
pub fn is_leaf(&self, index: NodeIndex) -> bool {
|
||||
self.leaves.contains(&index)
|
||||
}
|
||||
|
||||
/// Returns a vector of paths from every leaf to the root.
|
||||
pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> {
|
||||
let mut paths = Vec::new();
|
||||
self.leaves.iter().for_each(|&leaf| {
|
||||
paths.push((
|
||||
leaf,
|
||||
ValuePath {
|
||||
value: self.get_node(leaf).expect("Failed to get leaf node"),
|
||||
path: self.get_path(leaf).expect("Failed to get path"),
|
||||
},
|
||||
));
|
||||
});
|
||||
paths
|
||||
}
|
||||
|
||||
/// Returns a Merkle path from the node at the specified index to the root.
|
||||
///
|
||||
/// The node itself is not included in the path.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - the specified index has depth set to 0 or the depth is greater than the depth of this
|
||||
/// Merkle tree.
|
||||
/// - the specified index is not contained in the nodes map.
|
||||
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
|
||||
if index.is_root() {
|
||||
return Err(MerkleError::DepthTooSmall(index.depth()));
|
||||
} else if index.depth() > self.max_depth() {
|
||||
return Err(MerkleError::DepthTooBig(index.depth() as u64));
|
||||
}
|
||||
|
||||
if !self.nodes.contains_key(&index) {
|
||||
return Err(MerkleError::NodeIndexNotFoundInTree(index));
|
||||
}
|
||||
|
||||
let mut path = Vec::new();
|
||||
for _ in 0..index.depth() {
|
||||
let sibling_index = index.sibling();
|
||||
index.move_up();
|
||||
let sibling =
|
||||
self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
|
||||
path.push(sibling);
|
||||
}
|
||||
Ok(MerklePath::new(path))
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [PartialMerkleTree].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
|
||||
self.leaves.iter().map(|&leaf| {
|
||||
(
|
||||
leaf,
|
||||
self.get_node(leaf)
|
||||
.unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator over the inner nodes of this Merkle tree.
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
|
||||
inner_nodes.map(|(index, digest)| {
|
||||
let left_hash =
|
||||
self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
|
||||
let right_hash =
|
||||
self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
|
||||
InnerNodeInfo {
|
||||
value: *digest,
|
||||
left: *left_hash,
|
||||
right: *right_hash,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Adds the nodes of the specified Merkle path to this [PartialMerkleTree]. The `index_value`
|
||||
/// and `value` parameters specify the leaf node at which the path starts.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The depth of the specified node_index is greater than 64 or smaller than 1.
|
||||
/// - The specified path is not consistent with other paths in the set (i.e., resolves to a
|
||||
/// different root).
|
||||
pub fn add_path(
|
||||
&mut self,
|
||||
index_value: u64,
|
||||
value: RpoDigest,
|
||||
path: MerklePath,
|
||||
) -> Result<(), MerkleError> {
|
||||
let index_value = NodeIndex::new(path.len() as u8, index_value)?;
|
||||
|
||||
Self::check_depth(index_value.depth())?;
|
||||
self.update_depth(index_value.depth());
|
||||
|
||||
// add provided node and its sibling to the leaves set
|
||||
self.leaves.insert(index_value);
|
||||
let sibling_node_index = index_value.sibling();
|
||||
self.leaves.insert(sibling_node_index);
|
||||
|
||||
// add provided node and its sibling to the nodes map
|
||||
self.nodes.insert(index_value, value);
|
||||
self.nodes.insert(sibling_node_index, path[0]);
|
||||
|
||||
// traverse to the root, updating the nodes
|
||||
let mut index_value = index_value;
|
||||
let node = Rpo256::merge(&index_value.build_node(value, path[0]));
|
||||
let root = path.iter().skip(1).copied().fold(node, |node, hash| {
|
||||
index_value.move_up();
|
||||
// insert calculated node to the nodes map
|
||||
self.nodes.insert(index_value, node);
|
||||
|
||||
// if the calculated node was a leaf, remove it from leaves set.
|
||||
self.leaves.remove(&index_value);
|
||||
|
||||
let sibling_node = index_value.sibling();
|
||||
|
||||
// Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only
|
||||
// if it is a new node (it wasn't in nodes map).
|
||||
// Node can be in 3 states: internal node, leaf of the tree and not a tree node at all.
|
||||
// - Internal node can only stay in this state -- addition of a new path can't make it
|
||||
// a leaf or remove it from the tree.
|
||||
// - Leaf node can stay in the same state (remain a leaf) or can become an internal
|
||||
// node. In the first case we don't need to do anything, and the second case is handled
|
||||
// by the call of `self.leaves.remove(&index_value);`
|
||||
// - New node can be a calculated node or a "sibling" node from a Merkle Path:
|
||||
// --- Calculated node, obviously, never can be a leaf.
|
||||
// --- Sibling node can be only a leaf, because otherwise it is not a new node.
|
||||
if self.nodes.insert(sibling_node, hash).is_none() {
|
||||
self.leaves.insert(sibling_node);
|
||||
}
|
||||
|
||||
Rpo256::merge(&index_value.build_node(node, hash))
|
||||
});
|
||||
|
||||
// if the path set is empty (the root is all ZEROs), set the root to the root of the added
|
||||
// path; otherwise, the root of the added path must be identical to the current root
|
||||
if self.root() == EMPTY_DIGEST {
|
||||
self.nodes.insert(ROOT_INDEX, root);
|
||||
} else if self.root() != root {
|
||||
return Err(MerkleError::ConflictingRoots {
|
||||
expected_root: self.root(),
|
||||
actual_root: root,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
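
    // Illustrative sketch (not part of the original diff): the root-consistency check above
    // means paths can be added incrementally, but only if they open against the same root. The
    // helper and its arguments here are hypothetical.
    #[cfg(test)]
    #[allow(dead_code)]
    fn _add_path_consistency_sketch(
        tree: &mut PartialMerkleTree,
        opening_a: (u64, RpoDigest, MerklePath),
        opening_b: (u64, RpoDigest, MerklePath),
    ) -> Result<(), MerkleError> {
        tree.add_path(opening_a.0, opening_a.1, opening_a.2)?;
        // fails with MerkleError::ConflictingRoots if opening_b resolves to a different root
        tree.add_path(opening_b.0, opening_b.1, opening_b.2)
    }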

    /// Updates the value of the leaf at the specified index, returning the old leaf value.
    ///
    /// By default the specified index is assumed to belong to the deepest layer. If the
    /// considered node does not belong to the tree, the first node on the way to the root will
    /// be changed.
    ///
    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
    ///
    /// # Errors
    /// Returns an error if:
    /// - No entry exists at the specified index.
    /// - The specified index is greater than the maximum number of nodes on the deepest layer.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
        let mut node_index = NodeIndex::new(self.max_depth(), index)?;

        // proceed to the leaf
        for _ in 0..node_index.depth() {
            if !self.leaves.contains(&node_index) {
                node_index.move_up();
            }
        }

        // add the node value to the nodes map
        let old_value = self
            .nodes
            .insert(node_index, value.into())
            .ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;

        // if the old value and the new value are the same, there is nothing to update
        if value == *old_value {
            return Ok(old_value);
        }

        let mut value = value.into();
        for _ in 0..node_index.depth() {
            let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
            value = Rpo256::merge(&node_index.build_node(value, *sibling));
            node_index.move_up();
            self.nodes.insert(node_index, value);
        }

        Ok(old_value)
    }
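
    // Illustrative sketch (not part of the original diff): the climb performed at the top of
    // `update_leaf`. Starting from an index on the deepest layer, the index moves up until it
    // lands on a node in the leaves set, so an update below a shallow leaf redirects to that
    // leaf. `leaves` and `index` are hypothetical inputs.
    #[cfg(test)]
    #[allow(dead_code)]
    fn _climb_to_leaf_sketch(leaves: &BTreeSet<NodeIndex>, mut index: NodeIndex) -> NodeIndex {
        for _ in 0..index.depth() {
            if !leaves.contains(&index) {
                index.move_up();
            }
        }
        index
    }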

    // UTILITY FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Utility to visualize a [PartialMerkleTree] in text.
    pub fn print(&self) -> Result<String, fmt::Error> {
        let indent = "  ";
        let mut s = String::new();
        s.push_str("root: ");
        s.push_str(&word_to_hex(&self.root())?);
        s.push('\n');
        for d in 1..=self.max_depth() {
            let entries = 2u64.pow(d.into());
            for i in 0..entries {
                let index = NodeIndex::new(d, i).expect("The index must always be valid");
                let node = self.get_node(index);
                let node = match node {
                    Err(_) => continue,
                    Ok(node) => node,
                };

                for _ in 0..d {
                    s.push_str(indent);
                }
                s.push_str(&format!("({}, {}): ", index.depth(), index.value()));
                s.push_str(&word_to_hex(&node)?);
                s.push('\n');
            }
        }

        Ok(s)
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Updates the depth value with the maximum of the current and provided depths.
    fn update_depth(&mut self, new_depth: u8) {
        self.max_depth = new_depth.max(self.max_depth);
    }

    /// Returns an error if the depth is 0 or is greater than 64.
    fn check_depth(depth: u8) -> Result<(), MerkleError> {
        // validate the range of the depth.
        if depth < Self::MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(depth));
        } else if Self::MAX_DEPTH < depth {
            return Err(MerkleError::DepthTooBig(depth as u64));
        }
        Ok(())
    }
}

// SERIALIZATION
// ================================================================================================

impl Serializable for PartialMerkleTree {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // write leaf nodes
        target.write_u64(self.leaves.len() as u64);
        for leaf_index in self.leaves.iter() {
            leaf_index.write_into(target);
            self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
        }
    }
}
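
// Illustrative sketch (not part of the original diff): the wire format above is a u64 leaf count
// followed by one (NodeIndex, RpoDigest) pair per leaf; `read_from` below consumes the same
// layout and rebuilds the tree via `with_leaves`, so a round trip preserves equality.
#[cfg(test)]
#[allow(dead_code)]
fn _serialization_round_trip_sketch(pmt: &PartialMerkleTree) {
    let bytes = pmt.to_bytes();
    assert_eq!(PartialMerkleTree::read_from_bytes(&bytes).unwrap(), *pmt);
}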

impl Deserializable for PartialMerkleTree {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let leaves_len = source.read_u64()? as usize;
        let mut leaf_nodes = Vec::with_capacity(leaves_len);

        // add leaf nodes to the vector
        for _ in 0..leaves_len {
            let index = NodeIndex::read_from(source)?;
            let hash = RpoDigest::read_from(source)?;
            leaf_nodes.push((index, hash));
        }

        let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
            DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
        })?;

        Ok(pmt)
    }
}
466
src/merkle/partial_mt/tests.rs
Normal file
@ -0,0 +1,466 @@
use alloc::{collections::BTreeMap, vec::Vec};

use super::{
    super::{
        digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
        PartialMerkleTree,
    },
    Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};

// TEST DATA
// ================================================================================================

const NODE10: NodeIndex = NodeIndex::new_unchecked(1, 0);
const NODE11: NodeIndex = NodeIndex::new_unchecked(1, 1);

const NODE20: NodeIndex = NodeIndex::new_unchecked(2, 0);
const NODE21: NodeIndex = NodeIndex::new_unchecked(2, 1);
const NODE22: NodeIndex = NodeIndex::new_unchecked(2, 2);
const NODE23: NodeIndex = NodeIndex::new_unchecked(2, 3);

const NODE30: NodeIndex = NodeIndex::new_unchecked(3, 0);
const NODE31: NodeIndex = NodeIndex::new_unchecked(3, 1);
const NODE32: NodeIndex = NodeIndex::new_unchecked(3, 2);
const NODE33: NodeIndex = NodeIndex::new_unchecked(3, 3);

const VALUES8: [RpoDigest; 8] = [
    int_to_node(30),
    int_to_node(31),
    int_to_node(32),
    int_to_node(33),
    int_to_node(34),
    int_to_node(35),
    int_to_node(36),
    int_to_node(37),
];

// TESTS
// ================================================================================================

// For the Partial Merkle Tree tests we will use parts of the Merkle Tree whose full form is
// illustrated below:
//
//              __________ root __________
//             /                          \
//       ____ 10 ____                ____ 11 ____
//      /            \              /            \
//     20            21            22            23
//    /  \          /  \          /  \          /  \
//  (30) (31)     (32) (33)     (34) (35)     (36) (37)
//
// Where a node's number is a concatenation of its depth and index. For example, the node with
// NodeIndex(3, 5) is labeled `35`. Leaves of the tree are shown as nodes in parentheses, e.g.
// (33).

/// Checks that creation of the PMT with the `with_leaves()` constructor works correctly.
#[test]
fn with_leaves() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let leaf_nodes_vec = vec![
        (NODE20, mt.get_node(NODE20).unwrap()),
        (NODE32, mt.get_node(NODE32).unwrap()),
        (NODE33, mt.get_node(NODE33).unwrap()),
        (NODE22, mt.get_node(NODE22).unwrap()),
        (NODE23, mt.get_node(NODE23).unwrap()),
    ];

    let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();

    let pmt = PartialMerkleTree::with_leaves(leaf_nodes).unwrap();

    assert_eq!(expected_root, pmt.root())
}

/// Checks that the `with_leaves()` function returns an error when using an incomplete set of
/// nodes.
#[test]
fn err_with_leaves() {
    // NODE22 is missing
    let leaf_nodes_vec = vec![
        (NODE20, int_to_node(20)),
        (NODE32, int_to_node(32)),
        (NODE33, int_to_node(33)),
        (NODE23, int_to_node(23)),
    ];

    let leaf_nodes: BTreeMap<NodeIndex, RpoDigest> = leaf_nodes_vec.into_iter().collect();

    assert!(PartialMerkleTree::with_leaves(leaf_nodes).is_err());
}

/// Checks that the root returned by the `root()` function is equal to the expected one.
#[test]
fn get_root() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);
    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert_eq!(expected_root, pmt.root());
}

/// This test checks the correctness of the `add_path()` and `get_path()` functions. First it
/// creates a PMT using `add_path()` by adding Merkle paths from node 33 and node 22 to the empty
/// PMT. Then it checks that the paths returned by the `get_path()` function are equal to the
/// expected ones.
#[test]
fn add_and_get_paths() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let expected_path33 = ms.get_path(expected_root, NODE33).unwrap();
    let expected_path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, expected_path33.value, expected_path33.path.clone()).unwrap();
    pmt.add_path(2, expected_path22.value, expected_path22.path.clone()).unwrap();

    let path33 = pmt.get_path(NODE33).unwrap();
    let path22 = pmt.get_path(NODE22).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_path33.path, path33);
    assert_eq!(expected_path22.path, path22);
    assert_eq!(expected_root, actual_root);
}

/// Checks that the `get_node` function used on nodes 10 and 32 returns the expected values.
#[test]
fn get_node() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert_eq!(ms.get_node(expected_root, NODE32).unwrap(), pmt.get_node(NODE32).unwrap());
    assert_eq!(ms.get_node(expected_root, NODE10).unwrap(), pmt.get_node(NODE10).unwrap());
}

/// Updates leaves of the PMT using the `update_leaf()` function and checks that the new root of
/// the tree is equal to the expected one.
#[test]
fn update_leaf() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let root = mt.root();

    let mut ms = MerkleStore::from(&mt);
    let path33 = ms.get_path(root, NODE33).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    let new_value32 = int_to_node(132);
    let expected_root = ms.set_node(root, NODE32, new_value32).unwrap().root;

    pmt.update_leaf(2, *new_value32).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);

    let new_value20 = int_to_node(120);
    let expected_root = ms.set_node(expected_root, NODE20, new_value20).unwrap().root;

    pmt.update_leaf(0, *new_value20).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);

    let new_value11 = int_to_node(111);
    let expected_root = ms.set_node(expected_root, NODE11, new_value11).unwrap().root;

    pmt.update_leaf(6, *new_value11).unwrap();
    let actual_root = pmt.root();

    assert_eq!(expected_root, actual_root);
}

/// Checks that the paths of the PMT returned by the `to_paths()` function are equal to the
/// expected ones.
#[test]
fn get_paths() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, path33.value, path33.path).unwrap();
    pmt.add_path(2, path22.value, path22.path).unwrap();
    // After PMT creation with path33 (33; 32, 20, 11) and path22 (22; 23, 10) we will have this
    // tree:
    //
    //           ______root______
    //          /                \
    //      ___10___          ___11___
    //     /        \        /        \
    //   (20)       21     (22)      (23)
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
    // for each leaf.

    let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
    let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
        .iter()
        .map(|&leaf| {
            (
                leaf,
                ValuePath {
                    value: mt.get_node(leaf).unwrap(),
                    path: mt.get_path(leaf).unwrap(),
                },
            )
        })
        .collect();

    let actual_paths = pmt.to_paths();

    assert_eq!(expected_paths, actual_paths);
}

/// Checks the correctness of leaf determination when using the `leaves()` function.
#[test]
fn leaves() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();
    // After PMT creation with path33 (33; 32, 20, 11) we will have this tree:
    //
    //           ______root______
    //          /                \
    //      ___10___             (11)
    //     /        \
    //   (20)       21
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 11, 20, 32 and 33.

    let value11 = mt.get_node(NODE11).unwrap();
    let value20 = mt.get_node(NODE20).unwrap();
    let value32 = mt.get_node(NODE32).unwrap();
    let value33 = mt.get_node(NODE33).unwrap();

    let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];

    let expected_leaves = leaves.iter().copied();
    assert!(expected_leaves.eq(pmt.leaves()));

    pmt.add_path(2, path22.value, path22.path).unwrap();
    // After adding path22 (22; 23, 10) to the existing PMT we will have this tree:
    //
    //           ______root______
    //          /                \
    //      ___10___          ___11___
    //     /        \        /        \
    //   (20)       21     (22)      (23)
    //             /  \
    //           (32) (33)
    //
    // which has leaf nodes 20, 22, 23, 32 and 33.

    let value20 = mt.get_node(NODE20).unwrap();
    let value22 = mt.get_node(NODE22).unwrap();
    let value23 = mt.get_node(NODE23).unwrap();
    let value32 = mt.get_node(NODE32).unwrap();
    let value33 = mt.get_node(NODE33).unwrap();

    let leaves = vec![
        (NODE20, value20),
        (NODE22, value22),
        (NODE23, value23),
        (NODE32, value32),
        (NODE33, value33),
    ];

    let expected_leaves = leaves.iter().copied();
    assert!(expected_leaves.eq(pmt.leaves()));
}

/// Checks that the nodes of the PMT returned by the `inner_nodes()` function are equal to the
/// expected ones.
#[test]
fn test_inner_node_iterator() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    // get actual inner nodes
    let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();

    let expected_n00 = mt.root();
    let expected_n10 = mt.get_node(NODE10).unwrap();
    let expected_n11 = mt.get_node(NODE11).unwrap();
    let expected_n20 = mt.get_node(NODE20).unwrap();
    let expected_n21 = mt.get_node(NODE21).unwrap();
    let expected_n32 = mt.get_node(NODE32).unwrap();
    let expected_n33 = mt.get_node(NODE33).unwrap();

    // create a vector of the expected inner nodes
    let mut expected = vec![
        InnerNodeInfo {
            value: expected_n00,
            left: expected_n10,
            right: expected_n11,
        },
        InnerNodeInfo {
            value: expected_n10,
            left: expected_n20,
            right: expected_n21,
        },
        InnerNodeInfo {
            value: expected_n21,
            left: expected_n32,
            right: expected_n33,
        },
    ];

    assert_eq!(actual, expected);

    // add another path to the Partial Merkle Tree
    pmt.add_path(2, path22.value, path22.path).unwrap();

    // get the new actual inner nodes
    let actual: Vec<InnerNodeInfo> = pmt.inner_nodes().collect();

    let expected_n22 = mt.get_node(NODE22).unwrap();
    let expected_n23 = mt.get_node(NODE23).unwrap();

    let info_11 = InnerNodeInfo {
        value: expected_n11,
        left: expected_n22,
        right: expected_n23,
    };

    // add the new inner node to the existing vector
    expected.insert(2, info_11);

    assert_eq!(actual, expected);
}

/// Checks that the serialization and deserialization implementations for the PMT work correctly.
#[test]
fn serialization() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();
    let path22 = ms.get_path(expected_root, NODE22).unwrap();

    let pmt = PartialMerkleTree::with_paths([
        (3, path33.value, path33.path),
        (2, path22.value, path22.path),
    ])
    .unwrap();

    let serialized_pmt = pmt.to_bytes();
    let deserialized_pmt = PartialMerkleTree::read_from_bytes(&serialized_pmt).unwrap();

    assert_eq!(deserialized_pmt, pmt);
}

/// Checks that deserialization fails with incorrect data.
#[test]
fn err_deserialization() {
    let mut tree_bytes: Vec<u8> = vec![5];
    tree_bytes.append(&mut NODE20.to_bytes());
    tree_bytes.append(&mut int_to_node(20).to_bytes());

    tree_bytes.append(&mut NODE21.to_bytes());
    tree_bytes.append(&mut int_to_node(21).to_bytes());

    // a node with depth 1 could have index 0 or 1, but this one has 2
    tree_bytes.append(&mut vec![1, 2]);
    tree_bytes.append(&mut int_to_node(11).to_bytes());

    assert!(PartialMerkleTree::read_from_bytes(&tree_bytes).is_err());
}

/// Checks that adding a path with a different root causes an error.
#[test]
fn err_add_path() {
    let path33 = vec![int_to_node(1), int_to_node(2), int_to_node(3)].into();
    let path22 = vec![int_to_node(4), int_to_node(5)].into();

    let mut pmt = PartialMerkleTree::new();
    pmt.add_path(3, int_to_node(6), path33).unwrap();

    assert!(pmt.add_path(2, int_to_node(7), path22).is_err());
}

/// Checks that requesting a node which is not in the PMT causes an error.
#[test]
fn err_get_node() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.get_node(NODE22).is_err());
    assert!(pmt.get_node(NODE23).is_err());
    assert!(pmt.get_node(NODE30).is_err());
    assert!(pmt.get_node(NODE31).is_err());
}

/// Checks that requesting the path from a leaf which is not in the PMT causes an error.
#[test]
fn err_get_path() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.get_path(NODE22).is_err());
    assert!(pmt.get_path(NODE23).is_err());
    assert!(pmt.get_path(NODE30).is_err());
    assert!(pmt.get_path(NODE31).is_err());
}

/// Checks that updating a leaf at an index beyond the deepest layer's capacity causes an error.
#[test]
fn err_update_leaf() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
    let expected_root = mt.root();

    let ms = MerkleStore::from(&mt);

    let path33 = ms.get_path(expected_root, NODE33).unwrap();

    let mut pmt = PartialMerkleTree::with_paths([(3, path33.value, path33.path)]).unwrap();

    assert!(pmt.update_leaf(8, *int_to_node(38)).is_err());
}
282
src/merkle/path.rs
Normal file
@ -0,0 +1,282 @@
use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};

use super::{InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest};
use crate::{
    utils::{ByteReader, Deserializable, DeserializationError, Serializable},
    Word,
};

// MERKLE PATH
// ================================================================================================

/// A Merkle path container, composed of a sequence of nodes of a Merkle tree.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerklePath {
    nodes: Vec<RpoDigest>,
}

impl MerklePath {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates a new Merkle path from a list of nodes.
    pub fn new(nodes: Vec<RpoDigest>) -> Self {
        assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 255 items");
        Self { nodes }
    }

    // PROVIDERS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth at which this Merkle path proof is valid, i.e., the number of nodes in
    /// the path.
    pub fn depth(&self) -> u8 {
        self.nodes.len() as u8
    }

    /// Returns a reference to the [MerklePath]'s nodes.
    pub fn nodes(&self) -> &[RpoDigest] {
        &self.nodes
    }

    /// Computes the Merkle root for this opening.
    pub fn compute_root(&self, index: u64, node: RpoDigest) -> Result<RpoDigest, MerkleError> {
        let mut index = NodeIndex::new(self.depth(), index)?;
        let root = self.nodes.iter().copied().fold(node, |node, sibling| {
            // compute the node and move to the next iteration.
            let input = index.build_node(node, sibling);
            index.move_up();
            Rpo256::merge(&input)
        });
        Ok(root)
    }

    /// Verifies the Merkle opening proof towards the provided root.
    ///
    /// # Errors
    /// Returns an error if:
    /// - the provided node index is invalid.
    /// - the root calculated during the verification differs from the provided one.
    pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> Result<(), MerkleError> {
        let computed_root = self.compute_root(index, node)?;
        if &computed_root != root {
            return Err(MerkleError::ConflictingRoots {
                expected_root: *root,
                actual_root: computed_root,
            });
        }

        Ok(())
    }
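
    // Illustrative sketch (not part of the original diff): checking a hypothetical opening
    // against a known root. `index`, `leaf`, and `root` are assumed to come from a tree built
    // elsewhere; `verify` recomputes the root from the leaf and compares the two digests.
    #[cfg(test)]
    #[allow(dead_code)]
    fn _verify_usage_sketch(&self, index: u64, leaf: RpoDigest, root: &RpoDigest) -> bool {
        self.verify(index, leaf, root).is_ok()
    }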

    /// Returns an iterator over every inner node of this [MerklePath].
    ///
    /// The iteration order is unspecified.
    ///
    /// # Errors
    /// Returns an error if the specified index is not valid for this path.
    pub fn inner_nodes(
        &self,
        index: u64,
        node: RpoDigest,
    ) -> Result<InnerNodeIterator, MerkleError> {
        Ok(InnerNodeIterator {
            nodes: &self.nodes,
            index: NodeIndex::new(self.depth(), index)?,
            value: node,
        })
    }
}

// CONVERSIONS
// ================================================================================================

impl From<MerklePath> for Vec<RpoDigest> {
    fn from(path: MerklePath) -> Self {
        path.nodes
    }
}

impl From<Vec<RpoDigest>> for MerklePath {
    fn from(path: Vec<RpoDigest>) -> Self {
        Self::new(path)
    }
}

impl From<&[RpoDigest]> for MerklePath {
    fn from(path: &[RpoDigest]) -> Self {
        Self::new(path.to_vec())
    }
}

impl Deref for MerklePath {
    // we use `Vec` here instead of a slice so we can call vector mutation methods directly from
    // the Merkle path (example: `Vec::remove`).
    type Target = Vec<RpoDigest>;

    fn deref(&self) -> &Self::Target {
        &self.nodes
    }
}

impl DerefMut for MerklePath {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.nodes
    }
}

// ITERATORS
// ================================================================================================

impl FromIterator<RpoDigest> for MerklePath {
    fn from_iter<T: IntoIterator<Item = RpoDigest>>(iter: T) -> Self {
        Self::new(iter.into_iter().collect())
    }
}

impl IntoIterator for MerklePath {
    type Item = RpoDigest;
    type IntoIter = alloc::vec::IntoIter<RpoDigest>;

    fn into_iter(self) -> Self::IntoIter {
        self.nodes.into_iter()
    }
}

/// An iterator over the internal nodes of a [MerklePath].
pub struct InnerNodeIterator<'a> {
    nodes: &'a Vec<RpoDigest>,
    index: NodeIndex,
    value: RpoDigest,
}

impl Iterator for InnerNodeIterator<'_> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {
        if !self.index.is_root() {
            let sibling_pos = self.nodes.len() - self.index.depth() as usize;
            let (left, right) = if self.index.is_value_odd() {
                (self.nodes[sibling_pos], self.value)
            } else {
                (self.value, self.nodes[sibling_pos])
            };

            self.value = Rpo256::merge(&[left, right]);
            self.index.move_up();

            Some(InnerNodeInfo { value: self.value, left, right })
        } else {
            None
        }
    }
}
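
// Illustrative sketch (not part of the original diff): why `sibling_pos` above is
// `nodes.len() - depth`. Path siblings are stored bottom-up, so at depth `d` the matching
// sibling sits `d` entries from the end of the vector; the deepest level therefore reads the
// first stored sibling. The numbers below are hypothetical.
#[cfg(test)]
#[allow(dead_code)]
fn _sibling_pos_sketch() {
    let (path_len, depth) = (4usize, 4usize);
    assert_eq!(path_len - depth, 0); // the deepest level reads nodes[0]
}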

// MERKLE PATH CONTAINERS
// ================================================================================================

/// A container for a [crate::Word] value and its [MerklePath] opening.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
    /// The node value opening for `path`.
    pub value: RpoDigest,
    /// The path from `value` to `root` (exclusive).
    pub path: MerklePath,
}

impl ValuePath {
    /// Returns a new [ValuePath] instantiated from the specified value and path.
    pub fn new(value: RpoDigest, path: MerklePath) -> Self {
        Self { value, path }
    }
}

impl From<(MerklePath, Word)> for ValuePath {
    fn from((path, value): (MerklePath, Word)) -> Self {
        ValuePath::new(value.into(), path)
    }
}

/// A container for a [MerklePath] and its [crate::Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
/// root. For more information, check [MerklePath::verify].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RootPath {
    /// The root associated with `path`.
    pub root: RpoDigest,
    /// The path from a node to `root` (exclusive).
    pub path: MerklePath,
}

// SERIALIZATION
// ================================================================================================

impl Serializable for MerklePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
        target.write_u8(self.nodes.len() as u8);
        target.write_many(&self.nodes);
    }
}

impl Deserializable for MerklePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let count = source.read_u8()?.into();
        let nodes = source.read_many::<RpoDigest>(count)?;
        Ok(Self { nodes })
    }
}

impl Serializable for ValuePath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.value.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for ValuePath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let value = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { value, path })
    }
}

impl Serializable for RootPath {
    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
        self.root.write_into(target);
        self.path.write_into(target);
    }
}

impl Deserializable for RootPath {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let root = RpoDigest::read_from(source)?;
        let path = MerklePath::read_from(source)?;
        Ok(Self { root, path })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use crate::merkle::{int_to_node, MerklePath};

    #[test]
    fn test_inner_nodes() {
        let nodes = vec![int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
        let merkle_path = MerklePath::new(nodes);

        let index = 6;
        let node = int_to_node(5);
        let root = merkle_path.compute_root(index, node).unwrap();

        let inner_root = merkle_path.inner_nodes(index, node).unwrap().last().unwrap().value;

        assert_eq!(root, inner_root);
    }
}
605
src/merkle/smt/full/concurrent/mod.rs
Normal file
@ -0,0 +1,605 @@
use alloc::{collections::BTreeSet, vec::Vec};
use core::mem;

use num::Integer;

use super::{
    EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet,
    NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
};
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};

#[cfg(test)]
mod tests;

type MutatedSubtreeLeaves = Vec<Vec<SubtreeLeaf>>;

// CONCURRENT IMPLEMENTATIONS
// ================================================================================================

impl Smt {
    /// Parallel implementation of [`Smt::with_entries()`].
    ///
    /// This method constructs a new sparse Merkle tree concurrently by processing subtrees in
    /// parallel, working from the bottom up. The process works as follows:
    ///
    /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their
    ///    leaf indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
    ///
    /// 2. The subtrees are then processed in parallel:
    ///    - For each subtree, compute the inner nodes from depth D down to depth D-8.
    ///    - Each subtree computation yields a new subtree root and its associated inner nodes.
    ///
    /// 3. These subtree roots are recursively merged to become the "leaves" for the next
    ///    iteration, which processes the next 8 levels up. This continues until the final root
    ///    of the tree is computed at depth 0.
    pub(crate) fn with_entries_concurrent(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Self, MerkleError> {
        let mut seen_keys = BTreeSet::new();
        let entries: Vec<_> = entries
            .into_iter()
            .map(|(key, value)| {
                if seen_keys.insert(key) {
                    Ok((key, value))
                } else {
                    Err(MerkleError::DuplicateValuesForIndex(
                        LeafIndex::<SMT_DEPTH>::from(key).value(),
                    ))
                }
            })
            .collect::<Result<_, _>>()?;
        if entries.is_empty() {
            return Ok(Self::default());
        }
        let (inner_nodes, leaves) = Self::build_subtrees(entries);
        let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
        <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
    }
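
    // Illustrative sketch (not part of the original diff): the bottom-up ladder that the
    // concurrent methods walk. With SMT_DEPTH = 64 and SUBTREE_DEPTH = 8, each pass lifts the
    // working "leaves" eight levels, so the loops below visit depths 64, 56, ..., 8 and finish
    // with the root at depth 0.
    #[cfg(test)]
    #[allow(dead_code)]
    fn _subtree_ladder_sketch() {
        let depths: Vec<u8> =
            (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev().collect();
        assert_eq!(depths, [64, 56, 48, 40, 32, 24, 16, 8]);
    }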

    /// Parallel implementation of [`Smt::compute_mutations()`].
    ///
    /// This method computes mutations by recursively processing subtrees in parallel, working
    /// from the bottom up. The process works as follows:
    ///
    /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their
    ///    leaf indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
    ///
    /// 2. The subtrees containing modifications are then processed in parallel:
    ///    - For each modified subtree, compute node mutations from depth D up to depth D-8.
    ///    - Each subtree computation yields a new root at depth D-8 and its associated mutations.
    ///
    /// 3. These subtree roots become the "leaves" for the next iteration, which processes the
    ///    next 8 levels up. This continues until reaching the tree's root at depth 0.
    pub(crate) fn compute_mutations_concurrent(
        &self,
        kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> MutationSet<SMT_DEPTH, RpoDigest, Word>
    where
        Self: Sized + Sync,
    {
        use rayon::prelude::*;

        // Collect and sort key-value pairs by their corresponding leaf index
        let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
        sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());

        // Convert sorted pairs into mutated leaves and capture any new pairs
        let (mut subtree_leaves, new_pairs) =
            self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs);
        let mut node_mutations = NodeMutations::default();

        // Process each depth level in reverse, stepping by the subtree depth
        for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
            // Parallel processing of each subtree to generate mutations and roots
            let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
                .into_par_iter()
                .map(|subtree| {
                    debug_assert!(subtree.is_sorted() && !subtree.is_empty());
                    self.build_subtree_mutations(subtree, SMT_DEPTH, depth)
                })
                .unzip();

            // Prepare leaves for the next depth level
            subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();

            // Aggregate all node mutations
            node_mutations.extend(mutations_per_subtree.into_iter().flatten());

            debug_assert!(!subtree_leaves.is_empty());
        }

        // Finalize the mutation set with updated roots and mutations
        MutationSet {
            old_root: self.root(),
            new_root: subtree_leaves[0][0].hash,
            node_mutations,
            new_pairs,
        }
    }

    // SUBTREE MUTATION
    // --------------------------------------------------------------------------------------------

    /// Computes the node mutations and the root of a subtree.
    fn build_subtree_mutations(
        &self,
        mut leaves: Vec<SubtreeLeaf>,
        tree_depth: u8,
        bottom_depth: u8,
    ) -> (NodeMutations, SubtreeLeaf)
    where
        Self: Sized,
    {
        debug_assert!(bottom_depth <= tree_depth);
        debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
        debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));

        let subtree_root_depth = bottom_depth - SUBTREE_DEPTH;
        let mut node_mutations: NodeMutations = Default::default();
        let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);

        for current_depth in (subtree_root_depth..bottom_depth).rev() {
            debug_assert!(current_depth <= bottom_depth);

            let next_depth = current_depth + 1;
            let mut iter = leaves.drain(..).peekable();

            while let Some(first_leaf) = iter.next() {
                // This constructs a valid index because next_depth will never exceed the depth of
                // the tree.
                let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent();
                let parent_node = self.get_inner_node(parent_index);
                let combined_node = fetch_sibling_pair(&mut iter, first_leaf, parent_node);
                let combined_hash = combined_node.hash();

                let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth);

                // Add the parent node even if it is empty for proper upward updates
                next_leaves.push(SubtreeLeaf {
                    col: parent_index.value(),
                    hash: combined_hash,
                });

                node_mutations.insert(
                    parent_index,
                    if combined_hash != empty_hash {
                        NodeMutation::Addition(combined_node)
                    } else {
                        NodeMutation::Removal
                    },
                );
            }
            drop(iter);
            leaves = mem::take(&mut next_leaves);
        }

        debug_assert_eq!(leaves.len(), 1);
        let root_leaf = leaves.pop().unwrap();
        (node_mutations, root_leaf)
    }

    // SUBTREE CONSTRUCTION
    // --------------------------------------------------------------------------------------------

    /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
    ///
    /// `entries` need not be sorted. This function will sort them.
    fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
        entries.sort_by_key(|item| {
            let index = Self::key_to_leaf_index(&item.0);
            index.value()
        });
        Self::build_subtrees_from_sorted_entries(entries)
    }

    /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
    ///
    /// This function is mostly an implementation detail of
    /// [`Smt::with_entries_concurrent()`].
    fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
        use rayon::prelude::*;

        let mut accumulated_nodes: InnerNodes = Default::default();

        let PairComputations {
            leaves: mut leaf_subtrees,
            nodes: initial_leaves,
        } = Self::sorted_pairs_to_leaves(entries);

        for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
            let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
                leaf_subtrees
                    .into_par_iter()
                    .map(|subtree| {
                        debug_assert!(subtree.is_sorted());
                        debug_assert!(!subtree.is_empty());
                        let (nodes, subtree_root) =
                            build_subtree(subtree, SMT_DEPTH, current_depth);
                        (nodes, subtree_root)
                    })
                    .unzip();

            leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
            accumulated_nodes.extend(nodes.into_iter().flatten());

            debug_assert!(!leaf_subtrees.is_empty());
        }
        (accumulated_nodes, initial_leaves)
    }

    // LEAF NODE CONSTRUCTION
    // --------------------------------------------------------------------------------------------

    /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
    /// subtrees. In other words, this function takes the key-value inputs to the tree, and
    /// produces the inputs to feed into [`build_subtree()`].
    ///
    /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
    /// `pairs` is not correctly sorted, the returned computations will be incorrect.
    ///
    /// # Panics
    /// With debug assertions on, this function panics if it detects that `pairs` is not correctly
    /// sorted. Without debug assertions, the returned computations will be incorrect.
    fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> {
        Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
    }

    /// Computes leaves from a set of key-value pairs and the current leaf values.
    /// Derived from `sorted_pairs_to_leaves`.
    fn sorted_pairs_to_mutated_subtree_leaves(
        &self,
        pairs: Vec<(RpoDigest, Word)>,
    ) -> (MutatedSubtreeLeaves, UnorderedMap<RpoDigest, Word>) {
        // Map to track new key-value pairs for mutated leaves
        let mut new_pairs = UnorderedMap::new();

        let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
            let mut leaf = self.get_leaf(&leaf_pairs[0].0);

            for (key, value) in leaf_pairs {
                // Check if the value has changed
                let old_value =
                    new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));

                // Skip if the value hasn't changed
                if value == old_value {
                    continue;
                }

                // Otherwise, update the leaf and track the new key-value pair
                leaf = self.construct_prospective_leaf(leaf, &key, &value);
                new_pairs.insert(key, value);
            }

            leaf
        });
        (accumulator.leaves, new_pairs)
    }

    /// Processes sorted key-value pairs to compute leaves for a subtree.
    ///
    /// This function groups key-value pairs by their corresponding column index and processes
    /// each group to construct leaves. The actual construction of the leaf is delegated to the
    /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating
    /// new leaves or mutating existing ones).
    ///
    /// # Parameters
    /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index
    ///   column (not simply by key). If the input is not sorted correctly, the function will
    ///   produce incorrect results and may panic in debug mode.
    /// - `process_leaf`: A callback function used to process each group of key-value pairs
    ///   corresponding to the same column index. The callback takes a vector of key-value pairs
    ///   for a single column and returns the constructed leaf for that column.
    ///
    /// # Returns
    /// A `PairComputations<u64, Self::Leaf>` containing:
    /// - `nodes`: A mapping of column indices to the constructed leaves.
    /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves.
    ///   Each `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
    ///
    /// # Panics
    /// This function will panic in debug mode if the input `pairs` are not sorted by column
    /// index.
    fn process_sorted_pairs_to_leaves<F>(
        pairs: Vec<(RpoDigest, Word)>,
        mut process_leaf: F,
    ) -> PairComputations<u64, SmtLeaf>
    where
        F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf,
    {
        use rayon::prelude::*;
        debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));

        let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();

        // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
        // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
        // out and store them in our accumulated leaves.
        let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default();

        let mut iter = pairs.into_iter().peekable();
        while let Some((key, value)) = iter.next() {
            let col = Self::key_to_leaf_index(&key).index.value();
            let peeked_col = iter.peek().map(|(key, _v)| {
                let index = Self::key_to_leaf_index(key);
                let next_col = index.index.value();
                // We panic if `pairs` is not sorted by column.
                debug_assert!(next_col >= col);
                next_col
            });
            current_leaf_buffer.push((key, value));

            // If the next pair is in the same column as this one, then we're done after adding
            // this pair to the buffer.
            if peeked_col == Some(col) {
                continue;
            }

            // Otherwise, the next pair is in a different column, or there is no next pair.
            // Either way it's time to swap out our buffer.
            let leaf_pairs = mem::take(&mut current_leaf_buffer);
            let leaf = process_leaf(leaf_pairs);

            accumulator.nodes.insert(col, leaf);

            debug_assert!(current_leaf_buffer.is_empty());
        }

        // Compute the leaves from the nodes concurrently
        let mut accumulated_leaves: Vec<SubtreeLeaf> = accumulator
            .nodes
            .clone()
            .into_par_iter()
            .map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) })
            .collect();

        // Sort the leaves by column
        accumulated_leaves.par_sort_by_key(|leaf| leaf.col);

        // TODO: determine if there is any notable performance difference between computing
        // subtree boundaries after the fact as an iterator adapter (like this), versus computing
        // subtree boundaries as we go. Either way this function is only used at the beginning of
        // a parallel construction, so it should not be a critical path.
        accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
        accumulator
    }
}
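
// Illustrative sketch (not part of the original diff): the peek-and-flush grouping pattern used
// by `process_sorted_pairs_to_leaves`, shown on plain integers. Pairs sorted by a group key are
// buffered until the peeked key changes, then the buffer is flushed as one group; the values
// here are hypothetical.
#[cfg(test)]
#[allow(dead_code)]
fn _grouping_pattern_sketch() {
    let pairs = [(0u64, 'a'), (0, 'b'), (3, 'c')];
    let mut groups: Vec<Vec<char>> = Vec::new();
    let mut buffer: Vec<char> = Vec::new();
    let mut iter = pairs.into_iter().peekable();
    while let Some((col, v)) = iter.next() {
        buffer.push(v);
        // keep buffering while the next pair shares this column
        if iter.peek().map(|(c, _)| *c) == Some(col) {
            continue;
        }
        groups.push(mem::take(&mut buffer));
    }
    assert_eq!(groups.len(), 2);
    assert_eq!(groups[0], ['a', 'b']);
    assert_eq!(groups[1], ['c']);
}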

// SUBTREES
// ================================================================================================

/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;

/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
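
// Illustrative sketch (not part of the original diff): which depth-8 subtree a column falls in.
// A leaf at column `col` belongs to subtree `col / COLS_PER_SUBTREE`, so columns 0..=255 share
// subtree 0 and column 256 opens subtree 1.
#[cfg(test)]
#[allow(dead_code)]
fn _col_to_subtree_sketch() {
    assert_eq!(255u64 / COLS_PER_SUBTREE, 0);
    assert_eq!(256u64 / COLS_PER_SUBTREE, 1);
}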
|
||||
|
||||
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
|
||||
///
|
||||
/// Note that these represent "conceptual" leaves of some subtree, not necessarily
|
||||
/// the leaf type for the sparse Merkle tree.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
|
||||
pub struct SubtreeLeaf {
|
||||
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
|
||||
pub col: u64,
|
||||
/// The hash of the node this `SubtreeLeaf` represents.
|
||||
pub hash: RpoDigest,
|
||||
}
|
||||
|
||||
/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct PairComputations<K, L> {
|
||||
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
|
||||
pub nodes: UnorderedMap<K, L>,
|
||||
/// "Conceptual" leaves that will be used for computations.
|
||||
pub leaves: Vec<Vec<SubtreeLeaf>>,
|
||||
}
|
||||
|
||||
// Derive requires `L` to impl Default, even though we don't actually need that.
|
||||
impl<K, L> Default for PairComputations<K, L> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
nodes: Default::default(),
|
||||
leaves: Default::default(),
|
||||
}
|
||||
}
|
||||
}

#[derive(Debug)]
pub(crate) struct SubtreeLeavesIter<'s> {
    leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}

impl<'s> SubtreeLeavesIter<'s> {
    fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
        // TODO: determine if there is any notable performance difference between taking a Vec,
        // which may need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
        // The latter may have self-referential properties that are impossible to express in
        // purely safe Rust.
        Self { leaves: leaves.drain(..).peekable() }
    }
}

impl Iterator for SubtreeLeavesIter<'_> {
    type Item = Vec<SubtreeLeaf>;

    /// Each `next()` collects an entire subtree.
    fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
        let mut subtree: Vec<SubtreeLeaf> = Default::default();

        let mut last_subtree_col = 0;

        while let Some(leaf) = self.leaves.peek() {
            last_subtree_col = u64::max(1, last_subtree_col);
            let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
            let next_subtree_col = if is_exact_multiple {
                u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
            } else {
                last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
            };

            last_subtree_col = leaf.col;
            if leaf.col < next_subtree_col {
                subtree.push(self.leaves.next().unwrap());
            } else if subtree.is_empty() {
                continue;
            } else {
                break;
            }
        }

        if subtree.is_empty() {
            debug_assert!(self.leaves.peek().is_none());
            return None;
        }

        Some(subtree)
    }
}
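
// Editorial sketch (not in the source): demonstrates how `SubtreeLeavesIter` groups a
// flat, sorted list of leaves into depth-8 subtree buckets of 256 columns each.
// Assumes `vec!` is available as in the sibling tests module. Illustration only.
#[test]
fn sketch_subtree_leaves_iter_groups_by_column_range() {
    let mk = |col| SubtreeLeaf { col, hash: RpoDigest::default() };
    let mut leaves = vec![mk(3), mk(200), mk(300)];
    let groups: Vec<Vec<SubtreeLeaf>> = SubtreeLeavesIter::from_leaves(&mut leaves).collect();
    // Columns 3 and 200 fall in subtree 0 (columns 0..256); column 300 falls in subtree 1.
    assert_eq!(groups, vec![vec![mk(3), mk(200)], vec![mk(300)]]);
}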

// HELPER FUNCTIONS
// ================================================================================================

/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than `tree_depth`, or if
/// `leaves` is not sorted.
fn build_subtree(
    mut leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
    #[cfg(debug_assertions)]
    {
        // Ensure that all leaves have unique column indices within this subtree. In normal usage
        // via public APIs, this should never happen because leaf construction enforces
        // uniqueness. However, when testing or benchmarking `build_subtree()` in isolation,
        // duplicate columns can appear if input constraints are not enforced.
        let mut seen_cols = BTreeSet::new();
        for leaf in &leaves {
            assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col);
        }
    }
    debug_assert!(bottom_depth <= tree_depth);
    debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
    debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
    let subtree_root = bottom_depth - SUBTREE_DEPTH;
    let mut inner_nodes: UnorderedMap<NodeIndex, InnerNode> = Default::default();
    let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
    for next_depth in (subtree_root..bottom_depth).rev() {
        debug_assert!(next_depth <= bottom_depth);
        // `next_depth` is the level we're producing; `current_depth` is the level we have.
        let current_depth = next_depth + 1;
        let mut iter = leaves.drain(..).peekable();
        while let Some(first) = iter.next() {
            // On non-continuous iterations, including the first iteration, `first` may be either
            // a left or a right node. On subsequent continuous iterations, we will always call
            // `iter.next()` twice, consuming one sibling pair at a time.
            let is_right = first.col.is_odd();
            let (left, right) = if is_right {
                // Discontinuous iteration: we have no left node, so it must be empty.
                let left = SubtreeLeaf {
                    col: first.col - 1,
                    hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
                };
                let right = first;
                (left, right)
            } else {
                let left = first;
                let right_col = first.col + 1;
                let right = match iter.peek().copied() {
                    Some(SubtreeLeaf { col, .. }) if col == right_col => {
                        // Our inputs must be sorted.
                        debug_assert!(left.col <= col);
                        // The next leaf in the iterator is our sibling. Use it and consume it!
                        iter.next().unwrap()
                    },
                    // Otherwise, the leaves don't contain our sibling, so our sibling must be
                    // empty.
                    _ => SubtreeLeaf {
                        col: right_col,
                        hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
                    },
                };
                (left, right)
            };
            let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
            let node = InnerNode { left: left.hash, right: right.hash };
            let hash = node.hash();
            let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
            // If this hash is empty, then it doesn't become a new inner node, nor does it count
            // as a leaf for the next depth.
            if hash != equivalent_empty_hash {
                inner_nodes.insert(index, node);
                next_leaves.push(SubtreeLeaf { col: index.value(), hash });
            }
        }
        // Stop borrowing `leaves`, so we can swap it. The iterator is empty at this point anyway.
        drop(iter);
        // After each depth, consider the nodes we just made the new "leaves", and empty the
        // other collection.
        mem::swap(&mut leaves, &mut next_leaves);
    }
    debug_assert_eq!(leaves.len(), 1);
    let root = leaves.pop().unwrap();
    (inner_nodes, root)
}
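
// Editorial sketch (not in the source): `build_subtree` is designed to chain -- the root
// returned by one 8-layer pass becomes a "leaf" for the next pass, with `bottom_depth`
// reduced by `SUBTREE_DEPTH` each time. Assumes `SMT_DEPTH`, `ONE` (a `Felt`), and `vec!`
// are in scope as in the sibling tests module. Illustration only.
#[test]
fn sketch_build_subtree_chains_between_depth_levels() {
    use crate::ONE;
    let leaf = SubtreeLeaf { col: 0, hash: RpoDigest::new([ONE; 4]) };
    // Bottom pass: computes 8 layers of nodes, from depth 64 up to depth 56.
    let (_, mid_root) = build_subtree(vec![leaf], SMT_DEPTH, SMT_DEPTH);
    // Next pass starts 8 levels higher: depth 56 up to depth 48.
    let (_, upper_root) = build_subtree(vec![mid_root], SMT_DEPTH, SMT_DEPTH - SUBTREE_DEPTH);
    assert_eq!(upper_root.col, 0);
}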

/// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part:
/// - If `first_leaf` is a right child, the left child is copied from the `parent_node`.
/// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also
///   mutated, or copied from the `parent_node` otherwise.
///
/// Returns the `InnerNode` containing the hashes of the sibling pair.
fn fetch_sibling_pair(
    iter: &mut core::iter::Peekable<alloc::vec::Drain<SubtreeLeaf>>,
    first_leaf: SubtreeLeaf,
    parent_node: InnerNode,
) -> InnerNode {
    let is_right_node = first_leaf.col.is_odd();

    if is_right_node {
        let left_leaf = SubtreeLeaf {
            col: first_leaf.col - 1,
            hash: parent_node.left,
        };
        InnerNode {
            left: left_leaf.hash,
            right: first_leaf.hash,
        }
    } else {
        let right_col = first_leaf.col + 1;
        let right_leaf = match iter.peek().copied() {
            Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(),
            _ => SubtreeLeaf { col: right_col, hash: parent_node.right },
        };
        InnerNode {
            left: first_leaf.hash,
            right: right_leaf.hash,
        }
    }
}

#[cfg(feature = "internal")]
pub fn build_subtree_for_bench(
    leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
    build_subtree(leaves, tree_depth, bottom_depth)
}

459
src/merkle/smt/full/concurrent/tests.rs
Normal file
@@ -0,0 +1,459 @@
use alloc::{
    collections::{BTreeMap, BTreeSet},
    vec::Vec,
};

use rand::{prelude::IteratorRandom, thread_rng, Rng};

use super::{
    build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest,
    Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE,
    SMT_DEPTH, SUBTREE_DEPTH,
};
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE};

fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
    SubtreeLeaf {
        col: leaf.index().index.value(),
        hash: leaf.hash(),
    }
}

#[test]
fn test_sorted_pairs_to_leaves() {
    let entries: Vec<(RpoDigest, Word)> = vec![
        // Subtree 0.
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
        // Leaf index collision.
        (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
        (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
        // Subtree 1. Normal single leaf again.
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
        // Subtree 2. Another normal leaf.
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
    ];

    let control = Smt::with_entries_sequential(entries.clone()).unwrap();

    let control_leaves: Vec<SmtLeaf> = {
        let mut entries_iter = entries.iter().cloned();
        let mut next_entry = || entries_iter.next().unwrap();
        let control_leaves = vec![
            // Subtree 0.
            SmtLeaf::Single(next_entry()),
            SmtLeaf::Single(next_entry()),
            SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
            // Subtree 1.
            SmtLeaf::Single(next_entry()),
            SmtLeaf::Single(next_entry()),
            // Subtree 2.
            SmtLeaf::Single(next_entry()),
        ];
        assert_eq!(entries_iter.next(), None);
        control_leaves
    };

    let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
        let mut control_leaves_iter = control_leaves.iter();
        let mut next_leaf = || control_leaves_iter.next().unwrap();
        let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
            // Subtree 0.
            vec![next_leaf(), next_leaf(), next_leaf()],
            // Subtree 1.
            vec![next_leaf(), next_leaf()],
            // Subtree 2.
            vec![next_leaf()],
        ]
        .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
        .to_vec();
        assert_eq!(control_leaves_iter.next(), None);
        control_subtree_leaves
    };

    let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
    // This will check that the hashes, columns, and subtree assignments all match.
    assert_eq!(subtrees.leaves, control_subtree_leaves);
    // Flattening and re-separating out the leaves into subtrees should have the same result.
    let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
    let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
    assert_eq!(subtrees.leaves, re_grouped);
    // Then, finally, we might as well check the computed leaf nodes too.
    let control_leaves: BTreeMap<u64, SmtLeaf> = control
        .leaves()
        .map(|(index, value)| (index.index.value(), value.clone()))
        .collect();

    for (column, test_leaf) in subtrees.nodes {
        if test_leaf.is_empty() {
            continue;
        }
        let control_leaf = control_leaves
            .get(&column)
            .unwrap_or_else(|| panic!("no leaf node found for column {column}"));
        assert_eq!(control_leaf, &test_leaf);
    }
}

// Helper for the tests below.
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
    (0..pair_count)
        .map(|i| {
            // This expression evaluates to (approximately) `i`, spreading the keys across
            // `pair_count` distinct leaf indices.
            let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64;
            let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
            let value = [ONE, ONE, ONE, Felt::new(i)];
            (key, value)
        })
        .collect()
}

fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> {
    const REMOVAL_PROBABILITY: f64 = 0.2;
    let mut rng = thread_rng();
    // Assertion to ensure input keys are unique.
    assert!(
        entries.iter().map(|(key, _)| key).collect::<BTreeSet<_>>().len() == entries.len(),
        "Input entries contain duplicate keys!"
    );
    let mut sorted_entries: Vec<(RpoDigest, Word)> = entries
        .into_iter()
        .choose_multiple(&mut rng, updates)
        .into_iter()
        .map(|(key, _)| {
            let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
                EMPTY_WORD
            } else {
                [ONE, ONE, ONE, Felt::new(rng.gen())]
            };
            (key, value)
        })
        .collect();
    sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value());
    sorted_entries
}

#[test]
fn test_single_subtree() {
    // A single subtree's worth of leaves.
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
    let entries = generate_entries(PAIR_COUNT);
    let control = Smt::with_entries_sequential(entries.clone()).unwrap();
    // `entries` should already be sorted by nature of how we constructed it.
    let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
    let leaves = leaves.into_iter().next().unwrap();
    let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
    assert!(!first_subtree.is_empty());
    // The inner nodes computed from that subtree should match the nodes in our control tree.
    for (index, node) in first_subtree.into_iter() {
        let control = control.get_inner_node(index);
        assert_eq!(
            control, node,
            "subtree-computed node at index {index:?} does not match control",
        );
    }
    // The root returned should also match the equivalent node in the control tree.
    let control_root_index =
        NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
    let control_root_node = control.get_inner_node(control_root_index);
    let control_hash = control_root_node.hash();
    assert_eq!(
        control_hash, subtree_root.hash,
        "Subtree-computed root at index {control_root_index:?} does not match control"
    );
}

// Test that not only can we compute a subtree correctly, but we can feed the results of one
// subtree into computing another. In other words, test that `build_subtree()` is correctly
// composable.
#[test]
fn test_two_subtrees() {
    // Two subtrees' worth of leaves.
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
    let entries = generate_entries(PAIR_COUNT);
    let control = Smt::with_entries_sequential(entries.clone()).unwrap();
    let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
    // With two subtrees' worth of leaves, we should have exactly two subtrees.
    let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
    assert_eq!(first.len() as u64, PAIR_COUNT / 2);
    assert_eq!(first.len(), second.len());
    let mut current_depth = SMT_DEPTH;
    let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
    let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
    next_leaves.push(first_root);
    let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
    next_leaves.push(second_root);
    // All new inner nodes plus the new subtree leaves should total 512 for one depth-cycle.
    let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
    assert_eq!(total_computed as u64, PAIR_COUNT);
    // Verify the computed nodes of both subtrees.
    let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
    for (index, test_node) in computed_nodes {
        let control_node = control.get_inner_node(index);
        assert_eq!(
            control_node, test_node,
            "subtree-computed node at index {index:?} does not match control",
        );
    }
    current_depth -= SUBTREE_DEPTH;
    let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
    assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
    assert_eq!(root_leaf.col, 0);
    for (index, test_node) in nodes {
        let control_node = control.get_inner_node(index);
        assert_eq!(
            control_node, test_node,
            "subtree-computed node at index {index:?} does not match control",
        );
    }
    let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
    let control_root = control.get_inner_node(index).hash();
    assert_eq!(control_root, root_leaf.hash, "Root mismatch");
}

#[test]
fn test_singlethreaded_subtrees() {
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
    let entries = generate_entries(PAIR_COUNT);
    let control = Smt::with_entries_sequential(entries.clone()).unwrap();
    let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
    let PairComputations {
        leaves: mut leaf_subtrees,
        nodes: test_leaves,
    } = Smt::sorted_pairs_to_leaves(entries);
    for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
        // There's no flat_map_unzip(), so this is the best we can do.
        let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
            .into_iter()
            .enumerate()
            .map(|(i, subtree)| {
                // Pre-assertions.
                assert!(
                    subtree.is_sorted(),
                    "subtree {i} at bottom-depth {current_depth} is not sorted",
                );
                assert!(
                    !subtree.is_empty(),
                    "subtree {i} at bottom-depth {current_depth} is empty!",
                );
                // Do the actual work.
                let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
                // Post-assertions.
                for (&index, test_node) in nodes.iter() {
                    let control_node = control.get_inner_node(index);
                    assert_eq!(
                        test_node, &control_node,
                        "depth {} subtree {}: test node does not match control at index {:?}",
                        current_depth, i, index,
                    );
                }
                (nodes, subtree_root)
            })
            .unzip();
        // Update state between each depth iteration.
        leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
        accumulated_nodes.extend(nodes.into_iter().flatten());
        assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
    }
    // Make sure the true leaves match, first checking length and then checking each individual
    // leaf.
    let control_leaves: BTreeMap<_, _> = control.leaves().collect();
    let control_leaves_len = control_leaves.len();
    let test_leaves_len = test_leaves.len();
    assert_eq!(test_leaves_len, control_leaves_len);
    for (col, ref test_leaf) in test_leaves {
        let index = LeafIndex::new_max_depth(col);
        let &control_leaf = control_leaves.get(&index).unwrap();
        assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
    }
    // Make sure the inner nodes match, checking length first and then each individual node.
    let control_nodes_len = control.inner_nodes().count();
    let test_nodes_len = accumulated_nodes.len();
    assert_eq!(test_nodes_len, control_nodes_len);
    for (index, test_node) in accumulated_nodes.clone() {
        let control_node = control.get_inner_node(index);
        assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
    }
    // After the last iteration of the above for loop, we should have the new root node in two
    // places: once in `accumulated_nodes`, and again as the "next leaves" return from
    // `build_subtree()`. So let's check both!
    let control_root = control.get_inner_node(NodeIndex::root());
    // That for loop should have left us with only one leaf subtree...
    let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
    // ...which itself contains only one 'leaf'...
    let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
    // ...which matches the expected root.
    assert_eq!(control.root(), root_leaf.hash);
    // Likewise, `accumulated_nodes` should contain a node at the root index...
    assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
    // ...and it should match our actual root.
    let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
    assert_eq!(control_root, *test_root);
    // And of course the roots we got from each place should match.
    assert_eq!(control.root(), root_leaf.hash);
}

/// The parallel version of `test_singlethreaded_subtrees()`.
#[test]
fn test_multithreaded_subtrees() {
    use rayon::prelude::*;
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
    let entries = generate_entries(PAIR_COUNT);
    let control = Smt::with_entries_sequential(entries.clone()).unwrap();
    let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
    let PairComputations {
        leaves: mut leaf_subtrees,
        nodes: test_leaves,
    } = Smt::sorted_pairs_to_leaves(entries);
    for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
        let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
            .into_par_iter()
            .enumerate()
            .map(|(i, subtree)| {
                // Pre-assertions.
                assert!(
                    subtree.is_sorted(),
                    "subtree {i} at bottom-depth {current_depth} is not sorted",
                );
                assert!(
                    !subtree.is_empty(),
                    "subtree {i} at bottom-depth {current_depth} is empty!",
                );
                let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
                // Post-assertions.
                for (&index, test_node) in nodes.iter() {
                    let control_node = control.get_inner_node(index);
                    assert_eq!(
                        test_node, &control_node,
                        "depth {} subtree {}: test node does not match control at index {:?}",
                        current_depth, i, index,
                    );
                }
                (nodes, subtree_root)
            })
            .unzip();
        leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
        accumulated_nodes.extend(nodes.into_iter().flatten());
        assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
    }
    // Make sure the true leaves match, checking length first and then each individual leaf.
    let control_leaves: BTreeMap<_, _> = control.leaves().collect();
    let control_leaves_len = control_leaves.len();
    let test_leaves_len = test_leaves.len();
    assert_eq!(test_leaves_len, control_leaves_len);
    for (col, ref test_leaf) in test_leaves {
        let index = LeafIndex::new_max_depth(col);
        let &control_leaf = control_leaves.get(&index).unwrap();
        assert_eq!(test_leaf, control_leaf);
    }
    // Make sure the inner nodes match, checking length first and then each individual node.
    let control_nodes_len = control.inner_nodes().count();
    let test_nodes_len = accumulated_nodes.len();
    assert_eq!(test_nodes_len, control_nodes_len);
    for (index, test_node) in accumulated_nodes.clone() {
        let control_node = control.get_inner_node(index);
        assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
    }
    // After the last iteration of the above for loop, we should have the new root node in two
    // places: once in `accumulated_nodes`, and again as the "next leaves" return from
    // `build_subtree()`. So let's check both!
    let control_root = control.get_inner_node(NodeIndex::root());
    // That for loop should have left us with only one leaf subtree...
    let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
    // ...which itself contains only one 'leaf'...
    let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
    // ...which matches the expected root.
    assert_eq!(control.root(), root_leaf.hash);
    // Likewise, `accumulated_nodes` should contain a node at the root index...
    assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
    // ...and it should match our actual root.
    let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
    assert_eq!(control_root, *test_root);
    // And of course the roots we got from each place should match.
    assert_eq!(control.root(), root_leaf.hash);
}

#[test]
fn test_with_entries_concurrent() {
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
    let entries = generate_entries(PAIR_COUNT);
    let control = Smt::with_entries_sequential(entries.clone()).unwrap();
    let smt = Smt::with_entries(entries.clone()).unwrap();
    assert_eq!(smt.root(), control.root());
    assert_eq!(smt, control);
}

/// Concurrent mutations: checks the subtree-based mutation computation against the sequential
/// control, one depth level at a time.
#[test]
fn test_singlethreaded_subtree_mutations() {
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
    let entries = generate_entries(PAIR_COUNT);
    let updates = generate_updates(entries.clone(), 1000);
    let tree = Smt::with_entries_sequential(entries.clone()).unwrap();
    let control = tree.compute_mutations_sequential(updates.clone());
    let mut node_mutations = NodeMutations::default();
    let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates);
    for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
        // There's no flat_map_unzip(), so this is the best we can do.
        let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
            .into_iter()
            .enumerate()
            .map(|(i, subtree)| {
                // Pre-assertions.
                assert!(
                    subtree.is_sorted(),
                    "subtree {i} at bottom-depth {current_depth} is not sorted",
                );
                assert!(
                    !subtree.is_empty(),
                    "subtree {i} at bottom-depth {current_depth} is empty!",
                );
                // Calculate the mutations for this subtree.
                let (mutations_per_subtree, subtree_root) =
                    tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth);
                // Check that the mutations match the control tree.
                for (&index, mutation) in mutations_per_subtree.iter() {
                    let control_mutation = control.node_mutations().get(&index).unwrap();
                    assert_eq!(
                        control_mutation, mutation,
                        "depth {} subtree {}: mutation does not match control at index {:?}",
                        current_depth, i, index,
                    );
                }
                (mutations_per_subtree, subtree_root)
            })
            .unzip();
        subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
        node_mutations.extend(mutations_per_subtree.into_iter().flatten());
        assert!(!subtree_leaves.is_empty(), "on depth {current_depth}");
    }
    let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap();
    let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap();
    // Check that the new root matches the control.
    assert_eq!(control.new_root, root_leaf.hash);
    // Check that the node mutations match the control.
    assert_eq!(control.node_mutations().len(), node_mutations.len());
    for (&index, mutation) in control.node_mutations().iter() {
        let test_mutation = node_mutations.get(&index).unwrap();
        assert_eq!(test_mutation, mutation);
    }
    // Check that the new pairs match the control.
    assert_eq!(control.new_pairs.len(), new_pairs.len());
    for (&key, &value) in control.new_pairs.iter() {
        let test_value = new_pairs.get(&key).unwrap();
        assert_eq!(test_value, &value);
    }
}

#[test]
fn test_compute_mutations_parallel() {
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
    let entries = generate_entries(PAIR_COUNT);
    let tree = Smt::with_entries(entries.clone()).unwrap();
    let updates = generate_updates(entries, 1000);
    let control = tree.compute_mutations_sequential(updates.clone());
    let mutations = tree.compute_mutations(updates);
    assert_eq!(mutations.root(), control.root());
    assert_eq!(mutations.old_root(), control.old_root());
    assert_eq!(mutations.node_mutations(), control.node_mutations());
    assert_eq!(mutations.new_pairs(), control.new_pairs());
}

39
src/merkle/smt/full/error.rs
Normal file
@@ -0,0 +1,39 @@
use thiserror::Error;

use crate::{
    hash::rpo::RpoDigest,
    merkle::{LeafIndex, SMT_DEPTH},
};

// SMT LEAF ERROR
// =================================================================================================

#[derive(Debug, Error)]
pub enum SmtLeafError {
    #[error(
        "multiple leaf requires all keys to map to the same leaf index but key1 {key_1} and key2 {key_2} map to different indices"
    )]
    InconsistentMultipleLeafKeys { key_1: RpoDigest, key_2: RpoDigest },
    #[error("single leaf key {key} maps to {actual_leaf_index:?} but was expected to map to {expected_leaf_index:?}")]
    InconsistentSingleLeafIndices {
        key: RpoDigest,
        expected_leaf_index: LeafIndex<SMT_DEPTH>,
        actual_leaf_index: LeafIndex<SMT_DEPTH>,
    },
    #[error("supplied leaf index {leaf_index_supplied:?} does not match {leaf_index_from_keys:?} for multiple leaf")]
    InconsistentMultipleLeafIndices {
        leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
        leaf_index_supplied: LeafIndex<SMT_DEPTH>,
    },
    #[error("multiple leaf requires at least two entries but only {0} were given")]
    MultipleLeafRequiresTwoEntries(usize),
}

// SMT PROOF ERROR
// =================================================================================================

#[derive(Debug, Error)]
pub enum SmtProofError {
    #[error("merkle path length {0} does not match SMT depth {SMT_DEPTH}")]
    InvalidMerklePathLength(usize),
}
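
// Editorial sketch (not in the source): how one of these errors surfaces in practice, via
// `SmtLeaf::new_multiple` from the sibling `leaf` module. The `crate::merkle::SmtLeaf`
// re-export path is assumed. Illustration only.
#[cfg(test)]
mod error_sketch_tests {
    use alloc::vec::Vec;

    use super::SmtLeafError;
    use crate::merkle::SmtLeaf;

    #[test]
    fn multiple_leaf_requires_two_entries() {
        let err = SmtLeaf::new_multiple(Vec::new()).unwrap_err();
        assert!(matches!(err, SmtLeafError::MultipleLeafRequiresTwoEntries(0)));
    }
}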

373
src/merkle/smt/full/leaf.rs
Normal file
@@ -0,0 +1,373 @@
use alloc::{string::ToString, vec::Vec};
use core::cmp::Ordering;

use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
    Empty(LeafIndex<SMT_DEPTH>),
    Single((RpoDigest, Word)),
    Multiple(Vec<(RpoDigest, Word)>),
}

impl SmtLeaf {
    // CONSTRUCTORS
    // ---------------------------------------------------------------------------------------------

    /// Returns a new leaf with the specified entries.
    ///
    /// # Errors
    /// - Returns an error if two keys in `entries` map to different leaf indices.
    /// - Returns an error if one or more keys in `entries` map to a leaf index different from
    ///   `leaf_index`.
    pub fn new(
        entries: Vec<(RpoDigest, Word)>,
        leaf_index: LeafIndex<SMT_DEPTH>,
    ) -> Result<Self, SmtLeafError> {
        match entries.len() {
            0 => Ok(Self::new_empty(leaf_index)),
            1 => {
                let (key, value) = entries[0];

                let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
                if computed_index != leaf_index {
                    return Err(SmtLeafError::InconsistentSingleLeafIndices {
                        key,
                        expected_leaf_index: leaf_index,
                        actual_leaf_index: computed_index,
                    });
                }

                Ok(Self::new_single(key, value))
            },
            _ => {
                let leaf = Self::new_multiple(entries)?;

                // `new_multiple()` checked that all keys map to the same leaf index. We still
                // need to ensure that leaf index is `leaf_index`.
                if leaf.index() != leaf_index {
                    Err(SmtLeafError::InconsistentMultipleLeafIndices {
                        leaf_index_from_keys: leaf.index(),
                        leaf_index_supplied: leaf_index,
                    })
                } else {
                    Ok(leaf)
                }
            },
        }
    }

    /// Returns a new empty leaf with the specified leaf index.
    pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
        Self::Empty(leaf_index)
    }

    /// Returns a new single leaf with the specified entry. The leaf index is derived from the
    /// entry's key.
    pub fn new_single(key: RpoDigest, value: Word) -> Self {
        Self::Single((key, value))
    }

    /// Returns a new multiple leaf with the specified entries. The leaf index is derived from
    /// the entries' keys.
    ///
    /// # Errors
    /// - Returns an error if two keys in `entries` map to different leaf indices.
    pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
        if entries.len() < 2 {
            return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
        }

        // Check that all keys map to the same leaf index.
        {
            let mut keys = entries.iter().map(|(key, _)| key);

            let first_key = *keys.next().expect("ensured at least 2 entries");
            let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();

            for &next_key in keys {
                let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();

                if next_leaf_index != first_leaf_index {
                    return Err(SmtLeafError::InconsistentMultipleLeafKeys {
                        key_1: first_key,
                        key_2: next_key,
                    });
                }
            }
        }

        Ok(Self::Multiple(entries))
    }

    // PUBLIC ACCESSORS
    // ---------------------------------------------------------------------------------------------

    /// Returns true if the leaf is empty.
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Empty(_))
    }

    /// Returns the leaf's index in the [`super::Smt`].
    pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
        match self {
            SmtLeaf::Empty(leaf_index) => *leaf_index,
            SmtLeaf::Single((key, _)) => key.into(),
            SmtLeaf::Multiple(entries) => {
                // Note: all keys are guaranteed to have the same leaf index.
                let (first_key, _) = entries[0];
                first_key.into()
            },
        }
    }

    /// Returns the number of entries stored in the leaf.
    pub fn num_entries(&self) -> u64 {
        match self {
            SmtLeaf::Empty(_) => 0,
            SmtLeaf::Single(_) => 1,
            SmtLeaf::Multiple(entries) => {
                entries.len().try_into().expect("shouldn't have more than 2^64 entries")
            },
        }
    }

    /// Computes the hash of the leaf.
    pub fn hash(&self) -> RpoDigest {
        match self {
            SmtLeaf::Empty(_) => EMPTY_WORD.into(),
            SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
            SmtLeaf::Multiple(kvs) => {
                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
                Rpo256::hash_elements(&elements)
            },
        }
    }

    // ITERATORS
    // ---------------------------------------------------------------------------------------------

    /// Returns the key-value pairs in the leaf.
    pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
        }
    }

    // CONVERSIONS
    // ---------------------------------------------------------------------------------------------

    /// Converts a leaf to a list of field elements.
    pub fn to_elements(&self) -> Vec<Felt> {
        self.clone().into_elements()
    }

    /// Converts a leaf into a list of field elements.
    pub fn into_elements(self) -> Vec<Felt> {
        self.into_entries().into_iter().flat_map(kv_to_elements).collect()
    }

    /// Converts a leaf into the key-value pairs it contains.
    pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
        match self {
            SmtLeaf::Empty(_) => Vec::new(),
            SmtLeaf::Single(kv_pair) => vec![kv_pair],
            SmtLeaf::Multiple(kv_pairs) => kv_pairs,
        }
    }

    // HELPERS
    // ---------------------------------------------------------------------------------------------

    /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
    /// leaf.
    pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
        // Ensure that `key` maps to this leaf.
        if self.index() != key.into() {
            return None;
        }

        match self {
            SmtLeaf::Empty(_) => Some(EMPTY_WORD),
            SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
                if key == key_in_leaf {
                    Some(*value_in_leaf)
                } else {
                    Some(EMPTY_WORD)
                }
            },
            SmtLeaf::Multiple(kv_pairs) => {
                for (key_in_leaf, value_in_leaf) in kv_pairs {
                    if key == key_in_leaf {
                        return Some(*value_in_leaf);
                    }
                }

                Some(EMPTY_WORD)
            },
        }
    }

    /// Inserts a key-value pair into the leaf; returns the previous value associated with `key`,
    /// if any.
    ///
    /// The caller needs to ensure that `key` has the same leaf index as all other keys in the
    /// leaf.
    pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        match self {
            SmtLeaf::Empty(_) => {
                *self = SmtLeaf::new_single(key, value);
                None
            },
            SmtLeaf::Single(kv_pair) => {
                if kv_pair.0 == key {
                    // The key is already in this leaf. Update the value and return the previous
                    // value.
                    let old_value = kv_pair.1;
                    kv_pair.1 = value;
                    Some(old_value)
                } else {
                    // Another entry is present in this leaf. Transform the entry into a list
                    // entry, and make sure the key-value pairs are sorted by key.
                    let mut pairs = vec![*kv_pair, (key, value)];
                    pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));

                    *self = SmtLeaf::Multiple(pairs);

                    None
                }
            },
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;
                        kv_pairs[pos].1 = value;

                        Some(old_value)
                    },
                    Err(pos) => {
                        kv_pairs.insert(pos, (key, value));

                        None
                    },
                }
            },
        }
    }

    /// Removes the key-value pair stored at `key` from the leaf; returns the previous value
    /// associated with `key`, if any. Also returns an `is_empty` flag indicating whether the
    /// leaf became empty and must be removed from the data structure it is contained in.
    pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
        match self {
            SmtLeaf::Empty(_) => (None, false),
            SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
                if *key_at_leaf == key {
                    // Our key was indeed stored in the leaf, so we return the value that was
                    // stored in it, and indicate that the leaf should be removed.
                    let old_value = *value_at_leaf;

                    // Note: this is not strictly needed, since the caller is expected to drop
                    // this `SmtLeaf` object.
                    *self = SmtLeaf::new_empty(key.into());

                    (Some(old_value), true)
                } else {
                    // Another key is stored at this leaf; nothing to update.
                    (None, false)
                }
            },
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
                        let old_value = kv_pairs[pos].1;

                        kv_pairs.remove(pos);
                        debug_assert!(!kv_pairs.is_empty());

                        if kv_pairs.len() == 1 {
                            // Convert the leaf into `Single`.
                            *self = SmtLeaf::Single(kv_pairs[0]);
                        }

                        (Some(old_value), false)
                    },
                    Err(_) => {
                        // Other keys are stored at this leaf; nothing to update.
                        (None, false)
                    },
                }
            },
        }
    }
}
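
// Editorial sketch (not in the source): a single-entry leaf hashes to the merge of its
// key and value digests, matching `hash()` above. Assumes `ONE` (Felt::ONE) is
// re-exported at the crate root, as in the concurrent tests module. Illustration only.
#[test]
fn sketch_single_leaf_hash_is_merge_of_key_and_value() {
    use crate::ONE;
    let key = RpoDigest::new([ONE; 4]);
    let value: Word = [ONE; 4];
    let leaf = SmtLeaf::new_single(key, value);
    assert_eq!(leaf.hash(), Rpo256::merge(&[key, value.into()]));
}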

impl Serializable for SmtLeaf {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // Write: num entries
        self.num_entries().write_into(target);

        // Write: leaf index
        let leaf_index: u64 = self.index().value();
        leaf_index.write_into(target);

        // Write: entries
        for (key, value) in self.entries() {
            key.write_into(target);
            value.write_into(target);
        }
    }
}

impl Deserializable for SmtLeaf {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Read: num entries
        let num_entries = source.read_u64()?;

        // Read: leaf index
        let leaf_index: LeafIndex<SMT_DEPTH> = {
            let value = source.read_u64()?;
            LeafIndex::new_max_depth(value)
        };

        // Read: entries
        let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
        for _ in 0..num_entries {
            let key: RpoDigest = source.read()?;
            let value: Word = source.read()?;

            entries.push((key, value));
        }

        Self::new(entries, leaf_index)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}
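
// Editorial sketch (not in the source): a serialization round-trip through the impls
// above. `to_bytes()` and `read_from_bytes()` are the trait-provided helpers from
// winter_utils; `ONE` is assumed re-exported at the crate root. Illustration only.
#[test]
fn sketch_smt_leaf_serialization_round_trip() {
    use crate::ONE;
    let leaf = SmtLeaf::new_single(RpoDigest::new([ONE; 4]), [ONE; 4]);
    let bytes = leaf.to_bytes();
    assert_eq!(SmtLeaf::read_from_bytes(&bytes).unwrap(), leaf);
}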

// HELPER FUNCTIONS
// ================================================================================================

/// Converts a key-value tuple to an iterator of `Felt`s.
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
    let key_elements = key.into_iter();
    let value_elements = value.into_iter();

    key_elements.chain(value_elements)
}

/// Compares two keys element-by-element using their integer representations, starting with the
/// most significant element.
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
        let v1 = v1.as_int();
        let v2 = v2.as_int();
        if v1 != v2 {
            return v1.cmp(&v2);
        }
    }

    Ordering::Equal
}
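
// Editorial sketch (not in the source): `cmp_keys` orders by the most significant element
// (index 3) first, so differences in lower elements cannot override it. Illustration only.
#[test]
fn sketch_cmp_keys_orders_by_most_significant_element_first() {
    let smaller = RpoDigest::new([Felt::new(9), Felt::new(9), Felt::new(9), Felt::new(1)]);
    let larger = RpoDigest::new([Felt::new(0), Felt::new(0), Felt::new(0), Felt::new(2)]);
    assert_eq!(cmp_keys(smaller, larger), Ordering::Less);
}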

548
src/merkle/smt/full/mod.rs
Normal file
@@ -0,0 +1,548 @@
use alloc::{string::ToString, vec::Vec};

use super::{
    EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
    MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};

mod error;
pub use error::{SmtLeafError, SmtProofError};

mod leaf;
pub use leaf::SmtLeaf;

mod proof;
pub use proof::SmtProof;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// Concurrent implementation
#[cfg(feature = "concurrent")]
mod concurrent;
#[cfg(feature = "internal")]
pub use concurrent::{build_subtree_for_bench, SubtreeLeaf};

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

pub const SMT_DEPTH: u8 = 64;

// SMT
// ================================================================================================

type Leaves = super::Leaves<SmtLeaf>;

/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are
/// represented by 4 field elements.
///
/// All leaves sit at depth 64. The most significant element of the key is used to identify the
/// leaf to which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the
/// empty word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first,
/// value second.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
    root: RpoDigest,
    // pub(super) for use in PartialSmt.
    pub(super) leaves: Leaves,
    inner_nodes: InnerNodes,
}
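
// Editorial sketch (not in the source): a freshly constructed tree reports the depth-64
// empty-subtree root, matching `Smt::new()` below. Illustration only.
#[test]
fn sketch_empty_smt_root_matches_empty_subtree_root() {
    let smt = Smt::new();
    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
    assert!(smt.is_empty());
}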

impl Smt {
    // CONSTANTS
    // --------------------------------------------------------------------------------------------

    /// The default value used to compute the hash of empty leaves.
    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<SMT_DEPTH>>::EMPTY_VALUE;

    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a new [Smt].
    ///
    /// All leaves in the returned tree are set to [Self::EMPTY_VALUE].
    pub fn new() -> Self {
        let root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

        Self {
            root,
            inner_nodes: Default::default(),
            leaves: Default::default(),
        }
    }

    /// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
    ///
    /// If the `concurrent` feature is enabled, this function uses a parallel implementation to
    /// process the entries efficiently; otherwise, it defaults to the sequential implementation.
    ///
    /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    pub fn with_entries(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Self, MerkleError> {
        #[cfg(feature = "concurrent")]
        {
            Self::with_entries_concurrent(entries)
        }
        #[cfg(not(feature = "concurrent"))]
        {
            Self::with_entries_sequential(entries)
        }
    }

    /// Returns a new [Smt] instantiated with leaves set as specified by the provided entries.
    ///
    /// This sequential implementation processes entries one at a time to build the tree.
    /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
    ///
    /// # Errors
    /// Returns an error if the provided entries contain multiple values for the same key.
    #[cfg(any(not(feature = "concurrent"), test))]
    fn with_entries_sequential(
        entries: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> Result<Self, MerkleError> {
        use alloc::collections::BTreeSet;

        // Create an empty tree.
        let mut tree = Self::new();

        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`,
        // so entries with the empty value need additional tracking.
        let mut key_set_to_zero = BTreeSet::new();

        for (key, value) in entries {
            let old_value = tree.insert(key, value);

            if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
                return Err(MerkleError::DuplicateValuesForIndex(
                    LeafIndex::<SMT_DEPTH>::from(key).value(),
                ));
            }

            if value == EMPTY_WORD {
                key_set_to_zero.insert(key);
            };
        }
        Ok(tree)
    }

    /// Returns a new [`Smt`] instantiated from already computed leaves and nodes.
    ///
    /// This function performs minimal consistency checking. It is the caller's responsibility to
    /// ensure the passed arguments are correct and consistent with each other.
    ///
    /// # Panics
    /// With debug assertions on, this function panics if `root` does not match the root node in
    /// `inner_nodes`.
    pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
        // Our particular implementation of `from_raw_parts()` never returns `Err`.
        <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the tree.
    pub const fn depth(&self) -> u8 {
        SMT_DEPTH
    }

    /// Returns the root of the tree.
    pub fn root(&self) -> RpoDigest {
        <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
    }

    /// Returns the number of non-empty leaves in this tree.
    pub fn num_leaves(&self) -> usize {
        self.leaves.len()
    }

    /// Returns the leaf to which `key` maps.
    pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
        <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
    }

    /// Returns the value associated with `key`.
    pub fn get_value(&self, key: &RpoDigest) -> Word {
        <Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
    }

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a
    /// Merkle path to the leaf, as well as the leaf itself.
    pub fn open(&self, key: &RpoDigest) -> SmtProof {
        <Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
    }

    /// Returns a boolean value indicating whether the SMT is empty.
    pub fn is_empty(&self) -> bool {
        debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
        self.root == Self::EMPTY_ROOT
    }

    // ITERATORS
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the leaves of this [Smt].
    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
        self.leaves
            .iter()
            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
    }

    /// Returns an iterator over the key-value pairs of this [Smt].
    pub fn entries(&self) -> impl Iterator<Item = &(RpoDigest, Word)> {
        self.leaves().flat_map(|(_, leaf)| leaf.entries())
    }

    /// Returns an iterator over the inner nodes of this [Smt].
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a value at the specified key, returning the previous value associated with that
    /// key. Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    pub fn insert(&mut self, key: RpoDigest, value: Word) -> Word {
        <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
    }

    /// Computes what changes are necessary to insert the specified key-value pairs into this
    /// Merkle tree, allowing for validation before applying those changes.
    ///
    /// This method returns a [`MutationSet`], which contains all the information for inserting
    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which
    /// can be queried with [`MutationSet::root()`]. Once a mutation set is returned,
    /// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
    /// tree, or [`drop()`] to discard them.
    ///
    /// # Example
    /// ```
    /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
    /// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
    /// let mut smt = Smt::new();
    /// let pair = (RpoDigest::default(), Word::default());
    /// let mutations = smt.compute_mutations(vec![pair]);
    /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
    /// smt.apply_mutations(mutations);
    /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
    /// ```
    pub fn compute_mutations(
        &self,
        kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
    ) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
        #[cfg(feature = "concurrent")]
        {
            self.compute_mutations_concurrent(kv_pairs)
        }
        #[cfg(not(feature = "concurrent"))]
        {
            <Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
        }
    }

    /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
    ///
    /// # Errors
    /// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
    /// the `mutations` were computed against, and the second item is the actual current root of
    /// this tree.
    pub fn apply_mutations(
        &mut self,
        mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
    ) -> Result<(), MerkleError> {
        <Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
    }

    /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree
    /// and returns the reverse mutation set.
    ///
    /// Applying the reverse mutation set to the updated tree will revert the changes.
    ///
    /// # Errors
    /// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
    /// the `mutations` were computed against, and the second item is the actual current root of
    /// this tree.
    pub fn apply_mutations_with_reversion(
        &mut self,
        mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
    ) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
        <Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
    }

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Inserts `value` at the leaf index pointed to by `key`. `value` is guaranteed to not be
    /// the empty value, such that this is indeed an insertion.
    fn perform_insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
        debug_assert_ne!(value, Self::EMPTY_VALUE);

        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        match self.leaves.get_mut(&leaf_index.value()) {
            Some(leaf) => leaf.insert(key, value),
            None => {
                self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));

                None
            },
        }
    }

    /// Removes the key-value pair at the leaf index pointed to by `key`, if it exists.
    fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
        let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);

        if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
            let (old_value, is_empty) = leaf.remove(key);
            if is_empty {
                self.leaves.remove(&leaf_index.value());
            }
            old_value
        } else {
            // There's nothing stored at the leaf; nothing to update.
            None
        }
    }
}
|
||||
|
||||
impl SparseMerkleTree<SMT_DEPTH> for Smt {
    type Key = RpoDigest;
    type Value = Word;
    type Leaf = SmtLeaf;
    type Opening = SmtProof;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
    const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

    fn from_raw_parts(
        inner_nodes: InnerNodes,
        leaves: Leaves,
        root: RpoDigest,
    ) -> Result<Self, MerkleError> {
        if cfg!(debug_assertions) {
            let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
            assert_eq!(root_node.hash(), root);
        }

        Ok(Self { root, inner_nodes, leaves })
    }

    fn root(&self) -> RpoDigest {
        self.root
    }

    fn set_root(&mut self, root: RpoDigest) {
        self.root = root;
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        self.inner_nodes
            .get(&index)
            .cloned()
            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
    }

    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
        self.inner_nodes.insert(index, inner_node)
    }

    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
        self.inner_nodes.remove(&index)
    }

    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
        // inserting an `EMPTY_VALUE` is equivalent to removing any value associated with `key`
        if value != Self::EMPTY_VALUE {
            self.perform_insert(key, value)
        } else {
            self.perform_remove(key)
        }
    }

    fn get_value(&self, key: &Self::Key) -> Self::Value {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.get_value(key).unwrap_or_default(),
            None => EMPTY_WORD,
        }
    }

    fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();

        match self.leaves.get(&leaf_pos) {
            Some(leaf) => leaf.clone(),
            None => SmtLeaf::new_empty(key.into()),
        }
    }

    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest {
        leaf.hash()
    }

    fn construct_prospective_leaf(
        &self,
        mut existing_leaf: SmtLeaf,
        key: &RpoDigest,
        value: &Word,
    ) -> SmtLeaf {
        debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

        match existing_leaf {
            SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
            _ => {
                if *value != EMPTY_WORD {
                    existing_leaf.insert(*key, *value);
                } else {
                    existing_leaf.remove(*key);
                }

                existing_leaf
            },
        }
    }

    fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
        let most_significant_felt = key[3];
        LeafIndex::new_max_depth(most_significant_felt.as_int())
    }

    fn path_and_leaf_to_opening(path: MerklePath, leaf: SmtLeaf) -> SmtProof {
        SmtProof::new_unchecked(path, leaf)
    }

    fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
        assert!(!pairs.is_empty());

        if pairs.len() > 1 {
            SmtLeaf::new_multiple(pairs).unwrap()
        } else {
            let (key, value) = pairs.pop().unwrap();
            // TODO: should we ever be constructing empty leaves from pairs?
            if value == Self::EMPTY_VALUE {
                let index = Self::key_to_leaf_index(&key);
                SmtLeaf::new_empty(index)
            } else {
                SmtLeaf::new_single(key, value)
            }
        }
    }
}

impl Default for Smt {
    fn default() -> Self {
        Self::new()
    }
}

// CONVERSIONS
// ================================================================================================

impl From<Word> for LeafIndex<SMT_DEPTH> {
    fn from(value: Word) -> Self {
        // We use the most significant `Felt` of a `Word` as the leaf index.
        Self::new_max_depth(value[3].as_int())
    }
}

impl From<RpoDigest> for LeafIndex<SMT_DEPTH> {
    fn from(value: RpoDigest) -> Self {
        Word::from(value).into()
    }
}

impl From<&RpoDigest> for LeafIndex<SMT_DEPTH> {
    fn from(value: &RpoDigest) -> Self {
        Word::from(value).into()
    }
}

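// Illustrative sketch: the conversions above imply that any two keys sharing the same most
// significant `Felt` map to the same leaf. Using the constructors exercised in the tests below:
//
//     let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(42)]);
//     let key_b = RpoDigest::from([ONE, ONE, ONE + ONE, Felt::new(42)]);
//     assert_eq!(LeafIndex::<SMT_DEPTH>::from(key_a), LeafIndex::<SMT_DEPTH>::from(key_b));
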
// SERIALIZATION
// ================================================================================================

impl Serializable for Smt {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // Write the number of filled leaves for this Smt
        target.write_usize(self.entries().count());

        // Write each (key, value) pair
        for (key, value) in self.entries() {
            target.write(key);
            target.write(value);
        }
    }

    fn get_size_hint(&self) -> usize {
        let entries_count = self.entries().count();

        // Each entry is the size of a digest plus a word.
        entries_count.get_size_hint()
            + entries_count * (RpoDigest::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint())
    }
}

impl Deserializable for Smt {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // Read the number of filled leaves for this Smt
        let num_filled_leaves = source.read_usize()?;
        let mut entries = Vec::with_capacity(num_filled_leaves);

        for _ in 0..num_filled_leaves {
            let key = source.read()?;
            let value = source.read()?;
            entries.push((key, value));
        }

        Self::with_entries(entries)
            .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}

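// Illustrative note: the wire format above is simply a `usize` count followed by `count` raw
// `(RpoDigest, Word)` pairs, so a round trip looks like:
//
//     let bytes = smt.to_bytes();
//     let restored = Smt::read_from_bytes(&bytes)?;
//     assert_eq!(smt, restored);
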
// TESTS
// ================================================================================================

#[test]
fn test_smt_serialization_deserialization() {
    // Smt for default types (empty map)
    let smt_default = Smt::default();
    let bytes = smt_default.to_bytes();
    assert_eq!(smt_default, Smt::read_from_bytes(&bytes).unwrap());
    assert_eq!(bytes.len(), smt_default.get_size_hint());

    // Smt with values
    let smt_leaves_2: [(RpoDigest, Word); 2] = [
        (
            RpoDigest::new([Felt::new(101), Felt::new(102), Felt::new(103), Felt::new(104)]),
            [Felt::new(1_u64), Felt::new(2_u64), Felt::new(3_u64), Felt::new(4_u64)],
        ),
        (
            RpoDigest::new([Felt::new(105), Felt::new(106), Felt::new(107), Felt::new(108)]),
            [Felt::new(5_u64), Felt::new(6_u64), Felt::new(7_u64), Felt::new(8_u64)],
        ),
    ];
    let smt = Smt::with_entries(smt_leaves_2).unwrap();

    let bytes = smt.to_bytes();
    assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap());
    assert_eq!(bytes.len(), smt.get_size_hint());
}
115
src/merkle/smt/full/proof.rs
Normal file
@ -0,0 +1,115 @@
use alloc::string::ToString;

use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
///
/// The proof consists of a Merkle path and the leaf which describes the node located at the base
/// of that path.
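///
/// # Example
///
/// A membership-check sketch (illustrative only; assumes the `Smt` API used elsewhere in this
/// crate and arbitrary placeholder key/value):
///
/// ```
/// # use miden_crypto::{ONE, hash::rpo::RpoDigest, merkle::Smt};
/// let key = RpoDigest::default();
/// let value = [ONE; 4];
/// let smt = Smt::with_entries([(key, value)]).unwrap();
///
/// // Opening the leaf for `key` yields a proof of membership for `(key, value)`.
/// let proof = smt.open(&key);
/// assert!(proof.verify_membership(&key, &value, &smt.root()));
/// ```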
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmtProof {
    path: MerklePath,
    leaf: SmtLeaf,
}

impl SmtProof {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// # Errors
    /// Returns an error if the path length is not [`SMT_DEPTH`].
    pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
        let depth: usize = SMT_DEPTH.into();
        if path.len() != depth {
            return Err(SmtProofError::InvalidMerklePathLength(path.len()));
        }

        Ok(Self { path, leaf })
    }

    /// Returns a new instance of [`SmtProof`] instantiated from the specified path and leaf.
    ///
    /// The length of the path is not checked. Reserved for internal use.
    pub(super) fn new_unchecked(path: MerklePath, leaf: SmtLeaf) -> Self {
        Self { path, leaf }
    }

    // PROOF VERIFIER
    // --------------------------------------------------------------------------------------------

    /// Returns true if a [`super::Smt`] with the specified root contains the provided
    /// key-value pair.
    ///
    /// Note: this method cannot be used to assert non-membership. That is, if false is returned,
    /// it does not mean that the provided key-value pair is not in the tree.
    pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool {
        let maybe_value_in_leaf = self.leaf.get_value(key);

        match maybe_value_in_leaf {
            Some(value_in_leaf) => {
                // The value must match for the proof to be valid
                if value_in_leaf != *value {
                    return false;
                }

                // make sure the Merkle path resolves to the correct root
                self.compute_root() == *root
            },
            // If the key maps to a different leaf, the proof cannot verify membership of `value`
            None => false,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the value associated with the specified key according to this proof, or None if
    /// this proof does not contain a value for the specified key.
    ///
    /// A key-value pair generated by using this method should pass the `verify_membership()` check.
    pub fn get(&self, key: &RpoDigest) -> Option<Word> {
        self.leaf.get_value(key)
    }

    /// Computes the root of a [`super::Smt`] to which this proof resolves.
    pub fn compute_root(&self) -> RpoDigest {
        self.path
            .compute_root(self.leaf.index().value(), self.leaf.hash())
            .expect("failed to compute Merkle path root")
    }

    /// Returns the proof's Merkle path.
    pub fn path(&self) -> &MerklePath {
        &self.path
    }

    /// Returns the leaf associated with the proof.
    pub fn leaf(&self) -> &SmtLeaf {
        &self.leaf
    }

    /// Consumes the proof and returns its parts.
    pub fn into_parts(self) -> (MerklePath, SmtLeaf) {
        (self.path, self.leaf)
    }
}

impl Serializable for SmtProof {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.path.write_into(target);
        self.leaf.write_into(target);
    }
}

impl Deserializable for SmtProof {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let path = MerklePath::read_from(source)?;
        let leaf = SmtLeaf::read_from(source)?;

        Self::new(path, leaf).map_err(|err| DeserializationError::InvalidValue(err.to_string()))
    }
}
724
src/merkle/smt/full/tests.rs
Normal file
@ -0,0 +1,724 @@
use alloc::vec::Vec;

use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
    merkle::{
        smt::{NodeMutation, SparseMerkleTree, UnorderedMap},
        EmptySubtreeRoots, MerkleStore, MutationSet,
    },
    utils::{Deserializable, Serializable},
    Word, ONE, WORD_SIZE,
};

// SMT
// --------------------------------------------------------------------------------------------

/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the key is alone in its leaf.
#[test]
fn test_smt_insert_at_same_key() {
    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_1);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_empty_or_single_leaf_node(key_1, value_2);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}

/// This test checks that inserting twice at the same key functions as expected. The test covers
/// only the case where the leaf type is `SmtLeaf::Multiple`.
#[test]
fn test_smt_insert_at_same_key_2() {
    // The most significant u64 used for both keys (to ensure they map to the same leaf)
    let key_msb: u64 = 42;

    let key_already_present: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(key_msb)]);
    let key_already_present_index: NodeIndex =
        LeafIndex::<SMT_DEPTH>::from(key_already_present).into();
    let value_already_present = [ONE + ONE + ONE; WORD_SIZE];

    let mut smt =
        Smt::with_entries(core::iter::once((key_already_present, value_already_present))).unwrap();
    let mut store: MerkleStore = {
        let mut store = MerkleStore::default();

        let leaf_node = build_empty_or_single_leaf_node(key_already_present, value_already_present);
        store
            .set_node(*EmptySubtreeRoots::entry(SMT_DEPTH, 0), key_already_present_index, leaf_node)
            .unwrap();
        store
    };

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(key_msb)]);
    let key_1_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key_1).into();

    assert_eq!(key_1_index, key_already_present_index);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];

    // Insert value 1 and ensure root is as expected
    {
        // Note: key_1 comes first because it is smaller
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_1),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.root(), tree_root);
    }

    // Insert value 2 and ensure root is as expected
    {
        let leaf_node = build_multiple_leaf_node(&[
            (key_1, value_2),
            (key_already_present, value_already_present),
        ]);
        let tree_root = store.set_node(smt.root(), key_1_index, leaf_node).unwrap().root;

        let old_value_2 = smt.insert(key_1, value_2);
        assert_eq!(old_value_2, value_1);

        assert_eq!(smt.root(), tree_root);
    }
}

/// This test ensures that the root of the tree is as expected when we add/remove 3 items at 3
/// different keys. This also tests that the merkle paths produced are as expected.
#[test]
fn test_smt_insert_and_remove_multiple_values() {
    fn insert_values_and_assert_path(
        smt: &mut Smt,
        store: &mut MerkleStore,
        key_values: &[(RpoDigest, Word)],
    ) {
        for &(key, value) in key_values {
            let key_index: NodeIndex = LeafIndex::<SMT_DEPTH>::from(key).into();

            let leaf_node = build_empty_or_single_leaf_node(key, value);
            let tree_root = store.set_node(smt.root(), key_index, leaf_node).unwrap().root;

            let _ = smt.insert(key, value);

            assert_eq!(smt.root(), tree_root);

            let expected_path = store.get_path(tree_root, key_index).unwrap();
            assert_eq!(smt.open(&key).into_parts().0, expected_path.path);
        }
    }
    let mut smt = Smt::default();
    let mut store: MerkleStore = MerkleStore::default();

    assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));

    let key_1: RpoDigest = {
        let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_2: RpoDigest = {
        let raw = 0b_11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let key_3: RpoDigest = {
        let raw = 0b_00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000_u64;

        RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)])
    };

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [ONE + ONE; WORD_SIZE];
    let value_3 = [ONE + ONE + ONE; WORD_SIZE];

    // Insert values in the tree
    let key_values = [(key_1, value_1), (key_2, value_2), (key_3, value_3)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_values);

    // Remove values from the tree
    let key_empty_values = [(key_1, EMPTY_WORD), (key_2, EMPTY_WORD), (key_3, EMPTY_WORD)];
    insert_values_and_assert_path(&mut smt, &mut store, &key_empty_values);

    let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
    assert_eq!(smt.root(), empty_root);

    // an empty tree should have no leaves or inner nodes
    assert!(smt.leaves.is_empty());
    assert!(smt.inner_nodes.is_empty());
}

/// This tests that inserting the empty value does indeed remove the key-value contained at the
/// leaf. We insert & remove 3 values at the same leaf to ensure that all cases are covered (empty,
/// single, multiple).
#[test]
fn test_smt_removal() {
    let mut smt = Smt::default();

    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
    let key_3: RpoDigest =
        RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

    // insert key-value 1
    {
        let old_value_1 = smt.insert(key_1, value_1);
        assert_eq!(old_value_1, EMPTY_WORD);

        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1)));
    }

    // insert key-value 2
    {
        let old_value_2 = smt.insert(key_2, value_2);
        assert_eq!(old_value_2, EMPTY_WORD);

        assert_eq!(
            smt.get_leaf(&key_2),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // insert key-value 3
    {
        let old_value_3 = smt.insert(key_3, value_3);
        assert_eq!(old_value_3, EMPTY_WORD);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)])
        );
    }

    // remove key 3
    {
        let old_value_3 = smt.insert(key_3, EMPTY_WORD);
        assert_eq!(old_value_3, value_3);

        assert_eq!(
            smt.get_leaf(&key_3),
            SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)])
        );
    }

    // remove key 2
    {
        let old_value_2 = smt.insert(key_2, EMPTY_WORD);
        assert_eq!(old_value_2, value_2);

        assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1)));
    }

    // remove key 1
    {
        let old_value_1 = smt.insert(key_1, EMPTY_WORD);
        assert_eq!(old_value_1, value_1);

        assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into()));
    }
}

/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
    let mut smt = Smt::default();

    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
    // Sort key_3 before key_1, to test non-append insertion.
    let key_3: RpoDigest =
        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

    // insert key-value 1
    {
        let prospective =
            smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
        smt.insert(key_1, value_1);

        let leaf = smt.get_leaf(&key_1);
        assert_eq!(
            prospective,
            leaf.hash(),
            "prospective hash for leaf {leaf:?} did not match actual hash",
        );
    }

    // insert key-value 2
    {
        let prospective =
            smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
        smt.insert(key_2, value_2);

        let leaf = smt.get_leaf(&key_2);
        assert_eq!(
            prospective,
            leaf.hash(),
            "prospective hash for leaf {leaf:?} did not match actual hash",
        );
    }

    // insert key-value 3
    {
        let prospective =
            smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
        smt.insert(key_3, value_3);

        let leaf = smt.get_leaf(&key_3);
        assert_eq!(
            prospective,
            leaf.hash(),
            "prospective hash for leaf {leaf:?} did not match actual hash",
        );
    }

    // remove key 3
    {
        let old_leaf = smt.get_leaf(&key_3);
        let old_value_3 = smt.insert(key_3, EMPTY_WORD);
        assert_eq!(old_value_3, value_3);
        let prospective_leaf =
            smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);

        assert_eq!(
            old_leaf.hash(),
            prospective_leaf.hash(),
            "removing and prospectively re-adding a leaf didn't yield the original leaf:\
            \n original leaf: {old_leaf:?}\
            \n prospective leaf: {prospective_leaf:?}",
        );
    }

    // remove key 2
    {
        let old_leaf = smt.get_leaf(&key_2);
        let old_value_2 = smt.insert(key_2, EMPTY_WORD);
        assert_eq!(old_value_2, value_2);
        let prospective_leaf =
            smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);

        assert_eq!(
            old_leaf.hash(),
            prospective_leaf.hash(),
            "removing and prospectively re-adding a leaf didn't yield the original leaf:\
            \n original leaf: {old_leaf:?}\
            \n prospective leaf: {prospective_leaf:?}",
        );
    }

    // remove key 1
    {
        let old_leaf = smt.get_leaf(&key_1);
        let old_value_1 = smt.insert(key_1, EMPTY_WORD);
        assert_eq!(old_value_1, value_1);
        let prospective_leaf =
            smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
        assert_eq!(
            old_leaf.hash(),
            prospective_leaf.hash(),
            "removing and prospectively re-adding a leaf didn't yield the original leaf:\
            \n original leaf: {old_leaf:?}\
            \n prospective leaf: {prospective_leaf:?}",
        );
    }
}

/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
    let mut smt = Smt::default();

    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
    // Sort key_3 before key_1, to test non-append insertion.
    let key_3: RpoDigest =
        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];

    let root_empty = smt.root();

    let root_1 = {
        smt.insert(key_1, value_1);
        smt.root()
    };

    let root_2 = {
        smt.insert(key_2, value_2);
        smt.root()
    };

    let root_3 = {
        smt.insert(key_3, value_3);
        smt.root()
    };

    // Test incremental updates.

    let mut smt = Smt::default();

    let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
    assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
    let revert = apply_mutations(&mut smt, mutations);
    assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
    assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
    assert_eq!(
        revert.new_pairs,
        UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
        "reverse mutations pairs did not match"
    );
    assert_eq!(
        revert.node_mutations,
        smt.inner_nodes.keys().map(|key| (*key, NodeMutation::Removal)).collect(),
        "reverse mutations inner nodes did not match"
    );

    let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
    assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
    let mutations =
        smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
    assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
    let old_root = smt.root();
    let revert = apply_mutations(&mut smt, mutations);
    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
    assert_eq!(
        revert.new_pairs,
        UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
        "reverse mutations pairs did not match"
    );

    // Edge case: multiple values at the same key, where a later pair restores the original value.
    let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
    assert_eq!(mutations.root(), root_3);
    let old_root = smt.root();
    let revert = apply_mutations(&mut smt, mutations);
    assert_eq!(smt.root(), root_3);
    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
    assert_eq!(
        revert.new_pairs,
        UnorderedMap::from_iter([(key_3, value_3)]),
        "reverse mutations pairs did not match"
    );

    // Test batch updates, and that the order doesn't matter.
    let pairs =
        vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
    let mutations = smt.compute_mutations(pairs);
    assert_eq!(
        mutations.root(),
        root_empty,
        "prospective root for batch removal did not match actual root",
    );
    let old_root = smt.root();
    let revert = apply_mutations(&mut smt, mutations);
    assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
    assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
    assert_eq!(
        revert.new_pairs,
        UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
        "reverse mutations pairs did not match"
    );

    let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
    let mutations = smt.compute_mutations(pairs);
    assert_eq!(mutations.root(), root_3);
    smt.apply_mutations(mutations).unwrap();
    assert_eq!(smt.root(), root_3);
}

#[test]
fn test_mutations_revert() {
    let mut smt = Smt::default();

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
    let key_3: RpoDigest =
        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3 = [3_u32.into(); WORD_SIZE];

    smt.insert(key_1, value_1);
    smt.insert(key_2, value_2);

    let mutations =
        smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);

    let original = smt.clone();

    let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
    assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
    assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");

    smt.apply_mutations(revert).unwrap();

    assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
}

#[test]
fn test_mutation_set_serialization() {
    let mut smt = Smt::default();

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
    let key_3: RpoDigest =
        RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let value_3 = [3_u32.into(); WORD_SIZE];

    smt.insert(key_1, value_1);
    smt.insert(key_2, value_2);

    let mutations =
        smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);

    let serialized = mutations.to_bytes();
    let deserialized =
        MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();

    assert_eq!(deserialized, mutations, "deserialized mutations did not match original");

    let revert = smt.apply_mutations_with_reversion(mutations).unwrap();

    let serialized = revert.to_bytes();
    let deserialized =
        MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();

    assert_eq!(deserialized, revert, "deserialized mutations did not match original");
}

/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
    let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;

    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
    let key_2: RpoDigest =
        RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    assert_eq!(smt.open(&key_1), smt.open(&key_2));
}

/// Tests that an empty leaf hashes to the empty word
#[test]
fn test_empty_leaf_hash() {
    let smt = Smt::default();

    let leaf = smt.get_leaf(&RpoDigest::default());
    assert_eq!(leaf.hash(), EMPTY_WORD.into());
}

/// Tests that `get_value()` works as expected
#[test]
fn test_smt_get_value() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];

    let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap();

    let returned_value_1 = smt.get_value(&key_1);
    let returned_value_2 = smt.get_value(&key_2);

    assert_eq!(value_1, returned_value_1);
    assert_eq!(value_2, returned_value_2);

    // Check that a key with no inserted value returns the empty word
    let key_no_value = RpoDigest::from([42_u32, 42_u32, 42_u32, 42_u32]);

    assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
}

/// Tests that `entries()` works as expected
#[test]
fn test_smt_entries() {
    let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
    let entries = [(key_1, value_1), (key_2, value_2)];

    let smt = Smt::with_entries(entries).unwrap();

    let mut expected = Vec::from_iter(entries);
    expected.sort_by_key(|(k, _)| *k);
    let mut actual: Vec<_> = smt.entries().cloned().collect();
    actual.sort_by_key(|(k, _)| *k);

    assert_eq!(actual, expected);
}

/// Tests that the `EMPTY_ROOT` constant generated in the `Smt` equals the root of the empty tree
/// of depth 64
#[test]
fn test_smt_check_empty_root_constant() {
    // get the root of the empty tree of depth 64
    let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];

    assert_eq!(empty_root_64_depth, Smt::EMPTY_ROOT);
}

// SMT LEAF
// --------------------------------------------------------------------------------------------

#[test]
fn test_empty_smt_leaf_serialization() {
    let empty_leaf = SmtLeaf::new_empty(LeafIndex::new_max_depth(42));

    let mut serialized = empty_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();

    assert_eq!(empty_leaf, deserialized);
}

#[test]
fn test_single_smt_leaf_serialization() {
    let single_leaf = SmtLeaf::new_single(
        RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
        [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
    );

    let mut serialized = single_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();

    assert_eq!(single_leaf, deserialized);
}

#[test]
fn test_multiple_smt_leaf_serialization_success() {
    let multiple_leaf = SmtLeaf::new_multiple(vec![
        (
            RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
            [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
        ),
        (
            RpoDigest::from([100_u32, 101_u32, 102_u32, 13_u32]),
            [11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
        ),
    ])
    .unwrap();

    let mut serialized = multiple_leaf.to_bytes();
    // extend buffer with random bytes
    serialized.extend([1, 2, 3, 4, 5]);
    let deserialized = SmtLeaf::read_from_bytes(&serialized).unwrap();

    assert_eq!(multiple_leaf, deserialized);
}

// HELPERS
// --------------------------------------------------------------------------------------------

fn build_empty_or_single_leaf_node(key: RpoDigest, value: Word) -> RpoDigest {
    if value == EMPTY_WORD {
        SmtLeaf::new_empty(key.into()).hash()
    } else {
        SmtLeaf::Single((key, value)).hash()
    }
}

fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
    let elements: Vec<Felt> = kv_pairs
        .iter()
        .flat_map(|(key, value)| {
            let key_elements = key.into_iter();
            let value_elements = (*value).into_iter();

            key_elements.chain(value_elements)
        })
        .collect();

    Rpo256::hash_elements(&elements)
}

/// Applies mutations with and without reversion to the given SMT, comparing the resulting SMTs,
/// and returns the mutation set for reversion.
fn apply_mutations(
    smt: &mut Smt,
    mutation_set: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
    let mut smt2 = smt.clone();

    let reversion = smt.apply_mutations_with_reversion(mutation_set.clone()).unwrap();
    smt2.apply_mutations(mutation_set).unwrap();

    assert_eq!(&smt2, smt);

    reversion
}
712
src/merkle/smt/mod.rs
Normal file
@ -0,0 +1,712 @@
use alloc::vec::Vec;
use core::hash::Hash;

use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt, Word, EMPTY_WORD,
};

mod full;
#[cfg(feature = "internal")]
pub use full::{build_subtree_for_bench, SubtreeLeaf};
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};

mod simple;
pub use simple::SimpleSmt;

mod partial;
pub use partial::PartialSmt;

// CONSTANTS
// ================================================================================================

/// Minimum supported depth.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Maximum supported depth.
pub const SMT_MAX_DEPTH: u8 = 64;

// SPARSE MERKLE TREE
// ================================================================================================

/// A map whose keys are not guaranteed to be ordered.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;

/// An abstract description of a sparse Merkle tree.
///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
/// value was not explicitly set, then its value is the default value. Typically, the vast majority
/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
/// representation of the tree will only keep track of the leaves that have a different value from
/// the default.
///
/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
/// longer its proofs are - a proof contains exactly `depth` hashes, i.e. the base-2 log of the
/// number of leaves. A tree cannot have depth 0, since such a tree is just a single value, and is
/// probably a programming mistake.
///
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
    /// The type for a key
    type Key: Clone + Ord + Eq + Hash;
    /// The type for a value
    type Value: Clone + PartialEq;
    /// The type for a leaf
    type Leaf: Clone;
    /// The type for an opening (i.e. a "proof") of a leaf
    type Opening;

    /// The default value used to compute the hash of empty leaves
    const EMPTY_VALUE: Self::Value;

    /// The root of the empty tree with provided DEPTH
    const EMPTY_ROOT: RpoDigest;

    // PROVIDED METHODS
    // ---------------------------------------------------------------------------------------------

    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
    /// path to the leaf, as well as the leaf itself.
    fn open(&self, key: &Self::Key) -> Self::Opening {
        let leaf = self.get_leaf(key);

        let mut index: NodeIndex = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
            leaf_index.into()
        };

        let merkle_path = {
            let mut path = Vec::with_capacity(index.depth() as usize);
            for _ in 0..index.depth() {
                let is_right = index.is_value_odd();
                index.move_up();
                let InnerNode { left, right } = self.get_inner_node(index);
                let value = if is_right { left } else { right };
                path.push(value);
            }

            MerklePath::new(path)
        };

        Self::path_and_leaf_to_opening(merkle_path, leaf)
    }

    /// Inserts a value at the specified key, returning the previous value associated with that key.
    /// Recall that by definition, any key that hasn't been updated is associated with
    /// [`Self::EMPTY_VALUE`].
    ///
    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
    /// updating the root itself.
    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
        let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);

        // if the old value and new value are the same, there is nothing to update
        if value == old_value {
            return value;
        }

        let leaf = self.get_leaf(&key);
        let node_index = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
            leaf_index.into()
        };

        self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));

        old_value
    }

    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
    /// `node_hash_at_index` is the hash of the node stored at `index`.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: RpoDigest,
    ) {
        let mut node_hash = node_hash_at_index;
        for node_depth in (0..index.depth()).rev() {
            let is_right = index.is_value_odd();
            index.move_up();
            let InnerNode { left, right } = self.get_inner_node(index);
            let (left, right) = if is_right {
                (left, node_hash)
            } else {
                (node_hash, right)
            };
            node_hash = Rpo256::merge(&[left, right]);

            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
                // If a subtree is empty, then we can remove the inner node, since it's equal to the
                // default value
                self.remove_inner_node(index);
            } else {
                self.insert_inner_node(index, InnerNode { left, right });
            }
        }
        self.set_root(node_hash);
    }

    /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
    /// tree, allowing for validation before applying those changes.
    ///
    /// This method returns a [`MutationSet`], which contains all the information for inserting
    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
    /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
    /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
    /// the Merkle tree, or [`drop()`] to discard them.
    fn compute_mutations(
        &self,
        kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
    ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
        self.compute_mutations_sequential(kv_pairs)
    }

    /// Sequential version of [`SparseMerkleTree::compute_mutations()`].
    /// This is the default implementation.
    fn compute_mutations_sequential(
        &self,
        kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
    ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
        use NodeMutation::*;

        let mut new_root = self.root();
        let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
        let mut node_mutations: NodeMutations = Default::default();

        for (key, value) in kv_pairs {
            // If the old value and the new value are the same, there is nothing to update.
            // For the unusual case that kv_pairs has multiple values at the same key, we'll have
            // to check the key-value pairs we've already seen to get the "effective" old value.
            let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
            if value == old_value {
                continue;
            }

            let leaf_index = Self::key_to_leaf_index(&key);
            let mut node_index = NodeIndex::from(leaf_index);

            // We need the current leaf's hash to calculate the new leaf, but in the rare case that
            // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
            // part of the "current leaf".
            let old_leaf = {
                let pairs_at_index = new_pairs
                    .iter()
                    .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);

                pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
                    // Most of the time `pairs_at_index` should only contain a single entry (or
                    // none at all), as multi-leaves should be really rare.
                    let existing_leaf = acc.clone();
                    self.construct_prospective_leaf(existing_leaf, k, v)
                })
            };

            let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);

            let mut new_child_hash = Self::hash_leaf(&new_leaf);

            for node_depth in (0..node_index.depth()).rev() {
                // Whether the node we're replacing is the right child or the left child.
                let is_right = node_index.is_value_odd();
                node_index.move_up();

                let old_node = node_mutations
                    .get(&node_index)
                    .map(|mutation| match mutation {
                        Addition(node) => node.clone(),
                        Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
                    })
                    .unwrap_or_else(|| self.get_inner_node(node_index));

                let new_node = if is_right {
                    InnerNode {
                        left: old_node.left,
                        right: new_child_hash,
                    }
                } else {
                    InnerNode {
                        left: new_child_hash,
                        right: old_node.right,
                    }
                };

                // The next iteration will operate on this new node's hash.
                new_child_hash = new_node.hash();

                let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
                let is_removal = new_child_hash == equivalent_empty_hash;
                let new_entry = if is_removal { Removal } else { Addition(new_node) };
                node_mutations.insert(node_index, new_entry);
            }

            // Once we're at depth 0, the last node we made is the new root.
            new_root = new_child_hash;
            // And then we're done with this pair; on to the next one.
            new_pairs.insert(key, value);
        }

        MutationSet {
            old_root: self.root(),
            new_root,
            node_mutations,
            new_pairs,
        }
    }

    /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`]
    /// to this tree.
    ///
    /// # Errors
    /// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`] reporting both the root the `mutations` were computed
    /// against and the actual current root of this tree.
    fn apply_mutations(
        &mut self,
        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
    ) -> Result<(), MerkleError>
    where
        Self: Sized,
    {
        use NodeMutation::*;
        let MutationSet {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        } = mutations;

        // Guard against accidentally trying to apply mutations that were computed against a
        // different tree, including a stale version of this tree.
        if old_root != self.root() {
            return Err(MerkleError::ConflictingRoots {
                expected_root: self.root(),
                actual_root: old_root,
            });
        }

        for (index, mutation) in node_mutations {
            match mutation {
                Removal => {
                    self.remove_inner_node(index);
                },
                Addition(node) => {
                    self.insert_inner_node(index, node);
                },
            }
        }

        for (key, value) in new_pairs {
            self.insert_value(key, value);
        }

        self.set_root(new_root);

        Ok(())
    }

    /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`]
    /// to this tree and returns the reverse mutation set. Applying the reverse mutation set to the
    /// updated tree will revert the changes.
    ///
    /// # Errors
    /// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`] reporting both the root the `mutations` were computed
    /// against and the actual current root of this tree.
    fn apply_mutations_with_reversion(
        &mut self,
        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
    ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
    where
        Self: Sized,
    {
        use NodeMutation::*;
        let MutationSet {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        } = mutations;

        // Guard against accidentally trying to apply mutations that were computed against a
        // different tree, including a stale version of this tree.
        if old_root != self.root() {
            return Err(MerkleError::ConflictingRoots {
                expected_root: self.root(),
                actual_root: old_root,
            });
        }

        let mut reverse_mutations = NodeMutations::new();
        for (index, mutation) in node_mutations {
            match mutation {
                Removal => {
                    if let Some(node) = self.remove_inner_node(index) {
                        reverse_mutations.insert(index, Addition(node));
                    }
                },
                Addition(node) => {
                    if let Some(old_node) = self.insert_inner_node(index, node) {
                        reverse_mutations.insert(index, Addition(old_node));
                    } else {
                        reverse_mutations.insert(index, Removal);
                    }
                },
            }
        }

        let mut reverse_pairs = UnorderedMap::new();
        for (key, value) in new_pairs {
            if let Some(old_value) = self.insert_value(key.clone(), value) {
                reverse_pairs.insert(key, old_value);
            } else {
                reverse_pairs.insert(key, Self::EMPTY_VALUE);
            }
        }

        self.set_root(new_root);

        Ok(MutationSet {
            old_root: new_root,
            node_mutations: reverse_mutations,
            new_pairs: reverse_pairs,
            new_root: old_root,
        })
    }

// REQUIRED METHODS
|
||||
// ---------------------------------------------------------------------------------------------
|
||||
|
||||
/// Construct this type from already computed leaves and nodes. The caller ensures passed
|
||||
/// arguments are correct and consistent with each other.
|
||||
fn from_raw_parts(
|
||||
inner_nodes: InnerNodes,
|
||||
leaves: Leaves<Self::Leaf>,
|
||||
root: RpoDigest,
|
||||
) -> Result<Self, MerkleError>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// The root of the tree
|
||||
fn root(&self) -> RpoDigest;
|
||||
|
||||
/// Sets the root of the tree
|
||||
fn set_root(&mut self, root: RpoDigest);
|
||||
|
||||
/// Retrieves an inner node at the given index
|
||||
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
|
||||
|
||||
/// Inserts an inner node at the given index
|
||||
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
|
||||
|
||||
/// Removes an inner node at the given index
|
||||
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
|
||||
|
||||
/// Inserts a leaf node, and returns the value at the key if already exists
|
||||
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
|
||||
|
||||
/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
|
||||
/// updated is associated with [`Self::EMPTY_VALUE`].
|
||||
fn get_value(&self, key: &Self::Key) -> Self::Value;
|
||||
|
||||
/// Returns the leaf at the specified index.
|
||||
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
|
||||
|
||||
/// Returns the hash of a leaf
|
||||
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
|
||||
|
||||
/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
|
||||
/// mutating the tree itself. The existing leaf can be empty.
|
||||
///
|
||||
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
|
||||
/// as the argument for `existing_leaf`. The return value from this function can be chained back
|
||||
/// into this function as the first argument to continue making prospective changes.
|
||||
///
|
||||
/// # Invariants
|
||||
/// Because this method is for a prospective key-value insertion into a specific leaf,
|
||||
/// `existing_leaf` must have the same leaf index as `key` (as determined by
|
||||
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
|
||||
fn construct_prospective_leaf(
|
||||
&self,
|
||||
existing_leaf: Self::Leaf,
|
||||
key: &Self::Key,
|
||||
value: &Self::Value,
|
||||
) -> Self::Leaf;
|
||||
|
||||
/// Maps a key to a leaf index
|
||||
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
|
||||
|
||||
/// Constructs a single leaf from an arbitrary amount of key-value pairs.
|
||||
/// Those pairs must all have the same leaf index.
|
||||
fn pairs_to_leaf(pairs: Vec<(Self::Key, Self::Value)>) -> Self::Leaf;
|
||||
|
||||
/// Maps a (MerklePath, Self::Leaf) to an opening.
|
||||
///
|
||||
/// The length of `path` is guaranteed to be equal to `DEPTH`.
|
||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
|
||||
}
|
||||
|
||||
// INNER NODE
|
||||
// ================================================================================================
|
||||
|
||||
/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not
|
||||
/// part of the public API.
|
||||
#[doc(hidden)]
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct InnerNode {
|
||||
pub left: RpoDigest,
|
||||
pub right: RpoDigest,
|
||||
}
|
||||
|
||||
impl InnerNode {
|
||||
pub fn hash(&self) -> RpoDigest {
|
||||
Rpo256::merge(&[self.left, self.right])
|
||||
}
|
||||
}
|
||||
|
||||
// LEAF INDEX
|
||||
// ================================================================================================
|
||||
|
||||
/// The index of a leaf, at a depth known at compile-time.
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct LeafIndex<const DEPTH: u8> {
|
||||
index: NodeIndex,
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> LeafIndex<DEPTH> {
|
||||
pub fn new(value: u64) -> Result<Self, MerkleError> {
|
||||
if DEPTH < SMT_MIN_DEPTH {
|
||||
return Err(MerkleError::DepthTooSmall(DEPTH));
|
||||
}
|
||||
|
||||
Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
|
||||
}
|
||||
|
||||
pub fn value(&self) -> u64 {
|
||||
self.index.value()
|
||||
}
|
||||
}
|
||||
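// For example (depth chosen arbitrarily): `LeafIndex::<64>::new(42)` succeeds, while
// `LeafIndex::<D>::new(..)` with `D < SMT_MIN_DEPTH` returns `MerkleError::DepthTooSmall`.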
|
||||
impl LeafIndex<SMT_MAX_DEPTH> {
|
||||
pub const fn new_max_depth(value: u64) -> Self {
|
||||
LeafIndex {
|
||||
index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
|
||||
fn from(value: LeafIndex<DEPTH>) -> Self {
|
||||
value.index
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
|
||||
type Error = MerkleError;
|
||||
|
||||
fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
|
||||
if node_index.depth() != DEPTH {
|
||||
return Err(MerkleError::InvalidNodeIndexDepth {
|
||||
expected: DEPTH,
|
||||
provided: node_index.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
Self::new(node_index.value())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
self.index.write_into(target);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
Ok(Self { index: source.read()? })
|
||||
}
|
||||
}
|
||||
|
||||
// MUTATIONS
|
||||
// ================================================================================================
|
||||
|
||||
/// A change to an inner node of a sparse Merkle tree that hasn't yet been applied.
|
||||
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
|
||||
/// need to occur at which node indices.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum NodeMutation {
|
||||
/// Node needs to be removed.
|
||||
Removal,
|
||||
/// Node needs to be inserted.
|
||||
Addition(InnerNode),
|
||||
}
|
||||
|
||||
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
|
||||
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
|
||||
/// `SparseMerkleTree::apply_mutations()`.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
|
||||
/// The root of the Merkle tree this MutationSet is for, recorded at the time
|
||||
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
|
||||
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
|
||||
old_root: RpoDigest,
|
||||
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
|
||||
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
|
||||
/// index overlaid, if any. Each [`NodeMutation::Addition`] corresponds to a
|
||||
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
|
||||
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
|
||||
node_mutations: NodeMutations,
|
||||
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
|
||||
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
|
||||
/// back to the existing value in the Merkle tree. Each entry corresponds to a
|
||||
/// [`SparseMerkleTree::insert_value()`] call.
|
||||
new_pairs: UnorderedMap<K, V>,
|
||||
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
|
||||
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`] call.
|
||||
new_root: RpoDigest,
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
|
||||
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
|
||||
/// that method for more information.
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
self.new_root
|
||||
}
|
||||
|
||||
/// Returns the SMT root before the mutations were applied.
|
||||
pub fn old_root(&self) -> RpoDigest {
|
||||
self.old_root
|
||||
}
|
||||
|
||||
/// Returns the set of inner nodes that need to be removed or added.
|
||||
pub fn node_mutations(&self) -> &NodeMutations {
|
||||
&self.node_mutations
|
||||
}
|
||||
|
||||
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
|
||||
/// (i.e. set to `EMPTY_WORD`).
|
||||
pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
|
||||
&self.new_pairs
|
||||
}
|
||||
}
|
||||
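// Usage sketch: for a concrete tree type (e.g. a hypothetical `smt: SimpleSmt<DEPTH>`), the
// intended flow is roughly:
//
//     let mutations = smt.compute_mutations(kv_pairs);
//     assert_eq!(mutations.old_root(), smt.root());
//     smt.apply_mutations(mutations)?;
//
// where `kv_pairs` is any iterator of key-value pairs for that tree.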
|
||||
// SERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
impl Serializable for InnerNode {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write(self.left);
|
||||
target.write(self.right);
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for InnerNode {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let left = source.read()?;
|
||||
let right = source.read()?;
|
||||
|
||||
Ok(Self { left, right })
|
||||
}
|
||||
}
|
||||
|
||||
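// Encoding: a `NodeMutation` is serialized as a single boolean tag (`false` for `Removal`,
// `true` for `Addition`), with the `InnerNode` payload following only in the `Addition` case.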
impl Serializable for NodeMutation {
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
match self {
|
||||
NodeMutation::Removal => target.write_bool(false),
|
||||
NodeMutation::Addition(inner_node) => {
|
||||
target.write_bool(true);
|
||||
inner_node.write_into(target);
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deserializable for NodeMutation {
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
if source.read_bool()? {
|
||||
let inner_node = source.read()?;
|
||||
return Ok(NodeMutation::Addition(inner_node));
|
||||
}
|
||||
|
||||
Ok(NodeMutation::Removal)
|
||||
}
|
||||
}
|
||||
|
||||
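// Layout: `old_root`, then `new_root`, then the removal indices and the addition
// (index, node) pairs as separate vectors, then the number of new pairs followed by the
// pairs themselves.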
impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
|
||||
for MutationSet<DEPTH, K, V>
|
||||
{
|
||||
fn write_into<W: ByteWriter>(&self, target: &mut W) {
|
||||
target.write(self.old_root);
|
||||
target.write(self.new_root);
|
||||
|
||||
let inner_removals: Vec<_> = self
|
||||
.node_mutations
|
||||
.iter()
|
||||
.filter(|(_, value)| matches!(value, NodeMutation::Removal))
|
||||
.map(|(key, _)| key)
|
||||
.collect();
|
||||
let inner_additions: Vec<_> = self
|
||||
.node_mutations
|
||||
.iter()
|
||||
.filter_map(|(key, value)| match value {
|
||||
NodeMutation::Addition(node) => Some((key, node)),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
target.write(inner_removals);
|
||||
target.write(inner_additions);
|
||||
|
||||
target.write_usize(self.new_pairs.len());
|
||||
target.write_many(&self.new_pairs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
|
||||
for MutationSet<DEPTH, K, V>
|
||||
{
|
||||
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
|
||||
let old_root = source.read()?;
|
||||
let new_root = source.read()?;
|
||||
|
||||
let inner_removals: Vec<NodeIndex> = source.read()?;
|
||||
let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
|
||||
|
||||
let node_mutations = NodeMutations::from_iter(
|
||||
inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
|
||||
inner_additions
|
||||
.into_iter()
|
||||
.map(|(index, node)| (index, NodeMutation::Addition(node))),
|
||||
),
|
||||
);
|
||||
|
||||
let num_new_pairs = source.read_usize()?;
|
||||
let new_pairs = source.read_many(num_new_pairs)?;
|
||||
let new_pairs = UnorderedMap::from_iter(new_pairs);
|
||||
|
||||
Ok(Self {
|
||||
old_root,
|
||||
node_mutations,
|
||||
new_pairs,
|
||||
new_root,
|
||||
})
|
||||
}
|
||||
}
|
361
src/merkle/smt/partial.rs
Normal file
|
@ -0,0 +1,361 @@
|
|||
use crate::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::{smt::SparseMerkleTree, InnerNode, MerkleError, MerklePath, Smt, SmtLeaf, SmtProof},
|
||||
Word, EMPTY_WORD,
|
||||
};
|
||||
|
||||
/// A partial version of an [`Smt`].
|
||||
///
|
||||
/// This type can track a subset of the key-value pairs of a full [`Smt`] and allows for updating
|
||||
/// those pairs to compute the new root of the tree, as if the updates had been done on the full
|
||||
/// tree. This is useful so that not all leaves have to be present and loaded into memory to compute
|
||||
/// an update.
|
||||
///
|
||||
/// To facilitate this, a partial SMT requires that the merkle path of every key-value pair is
|
||||
/// added to the tree. Such a pair is then considered "tracked" and can be updated.
|
||||
///
|
||||
/// An important caveat is that only pairs whose merkle paths were added can be updated. Attempting
|
||||
/// to update an untracked value will result in an error. See [`PartialSmt::insert`] for more
|
||||
/// details.
|
||||
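/// # Example
///
/// A minimal sketch of the intended flow (`full`, `key`, and `new_value` are assumed to be in
/// scope; errors are unwrapped for brevity):
///
/// ```ignore
/// // Open the key in the full tree and build a partial tree from the proof.
/// let proof = full.open(&key);
/// let mut partial = PartialSmt::from_proofs([proof]).unwrap();
/// assert_eq!(partial.root(), full.root());
/// // Only tracked keys (those whose paths were added) can be updated.
/// partial.insert(key, new_value).unwrap();
/// ```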
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct PartialSmt(Smt);
|
||||
|
||||
impl PartialSmt {
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [`PartialSmt`].
|
||||
///
|
||||
/// All leaves in the returned tree are set to [`Smt::EMPTY_VALUE`].
|
||||
pub fn new() -> Self {
|
||||
Self(Smt::new())
|
||||
}
|
||||
|
||||
/// Instantiates a new [`PartialSmt`] by calling [`PartialSmt::add_path`] for all [`SmtProof`]s
|
||||
/// in the provided iterator.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the new root after the insertion of a (leaf, path) tuple does not match the existing root
|
||||
/// (except if the tree was previously empty).
|
||||
pub fn from_proofs<I>(paths: I) -> Result<Self, MerkleError>
|
||||
where
|
||||
I: IntoIterator<Item = SmtProof>,
|
||||
{
|
||||
let mut partial_smt = Self::new();
|
||||
|
||||
for (path, leaf) in paths.into_iter().map(SmtProof::into_parts) {
|
||||
partial_smt.add_path(leaf, path)?;
|
||||
}
|
||||
|
||||
Ok(partial_smt)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the root of the tree.
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
self.0.root()
|
||||
}
|
||||
|
||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||
/// path to the leaf, as well as the leaf itself.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the key is not tracked by this partial SMT.
|
||||
pub fn open(&self, key: &RpoDigest) -> Result<SmtProof, MerkleError> {
|
||||
if !self.is_leaf_tracked(key) {
|
||||
return Err(MerkleError::UntrackedKey(*key));
|
||||
}
|
||||
|
||||
Ok(self.0.open(key))
|
||||
}
|
||||
|
||||
/// Returns the leaf to which `key` maps
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the key is not tracked by this partial SMT.
|
||||
pub fn get_leaf(&self, key: &RpoDigest) -> Result<SmtLeaf, MerkleError> {
|
||||
if !self.is_leaf_tracked(key) {
|
||||
return Err(MerkleError::UntrackedKey(*key));
|
||||
}
|
||||
|
||||
Ok(self.0.get_leaf(key))
|
||||
}
|
||||
|
||||
/// Returns the value associated with `key`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the key is not tracked by this partial SMT.
|
||||
pub fn get_value(&self, key: &RpoDigest) -> Result<Word, MerkleError> {
|
||||
if !self.is_leaf_tracked(key) {
|
||||
return Err(MerkleError::UntrackedKey(*key));
|
||||
}
|
||||
|
||||
Ok(self.0.get_value(key))
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts a value at the specified key, returning the previous value associated with that key.
|
||||
/// Recall that by definition, any key that hasn't been updated is associated with
|
||||
/// [`Smt::EMPTY_VALUE`].
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
|
||||
/// updating the root itself.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the key and its merkle path were not previously added (using [`PartialSmt::add_path`]) to
|
||||
/// this [`PartialSmt`], which means it is almost certainly incorrect to update its value. If
|
||||
/// an error is returned the tree is in the same state as before.
|
||||
pub fn insert(&mut self, key: RpoDigest, value: Word) -> Result<Word, MerkleError> {
|
||||
if !self.is_leaf_tracked(&key) {
|
||||
return Err(MerkleError::UntrackedKey(key));
|
||||
}
|
||||
|
||||
let previous_value = self.0.insert(key, value);
|
||||
|
||||
// If the value was removed the SmtLeaf was removed as well by the underlying Smt
|
||||
// implementation. However, we still want to consider that leaf tracked so it can be
|
||||
// read and written to, so we reinsert an empty SmtLeaf.
|
||||
if value == EMPTY_WORD {
|
||||
let leaf_index = Smt::key_to_leaf_index(&key);
|
||||
self.0.leaves.insert(leaf_index.value(), SmtLeaf::Empty(leaf_index));
|
||||
}
|
||||
|
||||
Ok(previous_value)
|
||||
}
|
||||
|
||||
/// Adds a leaf and its merkle path to this [`PartialSmt`].
|
||||
///
|
||||
/// Once this function has been called, the `key` is tracked and can subsequently be updated
|
||||
/// to a new value, producing a correct new tree root.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - the new root after the insertion of the leaf and the path does not match the existing root
|
||||
/// (except if the tree was previously empty). If an error is returned, the tree is left in an
|
||||
/// inconsistent state.
|
||||
pub fn add_path(&mut self, leaf: SmtLeaf, path: MerklePath) -> Result<(), MerkleError> {
|
||||
let mut current_index = leaf.index().index;
|
||||
|
||||
let mut node_hash_at_current_index = leaf.hash();
|
||||
|
||||
// We insert directly into the leaves for two reasons:
|
||||
// - We can directly insert the leaf as it is without having to loop over its entries to
|
||||
// call Smt::perform_insert.
|
||||
// - If the leaf is SmtLeaf::Empty, we will also insert it, which means this leaf is
|
||||
// considered tracked by the partial SMT as it is part of the leaves map. When calling
|
||||
// PartialSmt::insert, this will not error for such empty leaves whose merkle path was
|
||||
// added, but will error for otherwise non-existent leaves whose paths were not added,
|
||||
// which is what we want.
|
||||
self.0.leaves.insert(current_index.value(), leaf);
|
||||
|
||||
for sibling_hash in path {
|
||||
// Find the index of the sibling node and compute whether it is a left or right child.
|
||||
let is_sibling_right = current_index.sibling().is_value_odd();
|
||||
|
||||
// Move the index up so it points to the parent of the current index and the sibling.
|
||||
current_index.move_up();
|
||||
|
||||
// Construct the new parent node from the child that was updated and the sibling from
|
||||
// the merkle path.
|
||||
let new_parent_node = if is_sibling_right {
|
||||
InnerNode {
|
||||
left: node_hash_at_current_index,
|
||||
right: sibling_hash,
|
||||
}
|
||||
} else {
|
||||
InnerNode {
|
||||
left: sibling_hash,
|
||||
right: node_hash_at_current_index,
|
||||
}
|
||||
};
|
||||
|
||||
self.0.insert_inner_node(current_index, new_parent_node);
|
||||
|
||||
node_hash_at_current_index = self.0.get_inner_node(current_index).hash();
|
||||
}
|
||||
|
||||
// Check the newly added merkle path is consistent with the existing tree. If not, the
|
||||
// merkle path was invalid or computed from another tree.
|
||||
// We skip this check if the root is empty since this indicates we're adding the first
|
||||
// merkle path in which case we have to update the tree root to the root from the path.
|
||||
if self.root() != Smt::EMPTY_ROOT && self.root() != node_hash_at_current_index {
|
||||
return Err(MerkleError::ConflictingRoots {
|
||||
expected_root: self.root(),
|
||||
actual_root: node_hash_at_current_index,
|
||||
});
|
||||
}
|
||||
|
||||
self.0.set_root(node_hash_at_current_index);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns true if the key's merkle path was previously added to this partial SMT and can be
|
||||
/// sensibly updated to a new value.
|
||||
/// In particular, this returns true for keys whose value was empty **but** their merkle paths
|
||||
/// were added, while it returns false if the merkle paths were **not** added.
|
||||
fn is_leaf_tracked(&self, key: &RpoDigest) -> bool {
|
||||
self.0.leaves.contains_key(&Smt::key_to_leaf_index(key).value())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PartialSmt {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use rand_utils::rand_array;
|
||||
|
||||
use super::*;
|
||||
use crate::{EMPTY_WORD, ONE, ZERO};
|
||||
|
||||
/// Tests that a basic PartialSmt can be built from a full one and that inserting or removing
|
||||
/// values whose merkle paths were added to the partial SMT results in the same root as the
|
||||
/// equivalent update in the full tree.
|
||||
#[test]
|
||||
fn partial_smt_insert_and_remove() {
|
||||
let key0 = RpoDigest::from(Word::from(rand_array()));
|
||||
let key1 = RpoDigest::from(Word::from(rand_array()));
|
||||
let key2 = RpoDigest::from(Word::from(rand_array()));
|
||||
// A key for which we won't add a value so it will be empty.
|
||||
let key_empty = RpoDigest::from(Word::from(rand_array()));
|
||||
|
||||
let value0 = Word::from(rand_array());
|
||||
let value1 = Word::from(rand_array());
|
||||
let value2 = Word::from(rand_array());
|
||||
|
||||
let mut kv_pairs = vec![(key0, value0), (key1, value1), (key2, value2)];
|
||||
|
||||
// Add more random leaves.
|
||||
kv_pairs.reserve(1000);
|
||||
for _ in 0..1000 {
|
||||
let key = RpoDigest::from(Word::from(rand_array()));
|
||||
let value = Word::from(rand_array());
|
||||
kv_pairs.push((key, value));
|
||||
}
|
||||
|
||||
let mut full = Smt::with_entries(kv_pairs).unwrap();
|
||||
|
||||
// Constructing a partial SMT from proofs succeeds.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
let proof0 = full.open(&key0);
|
||||
let proof2 = full.open(&key2);
|
||||
let proof_empty = full.open(&key_empty);
|
||||
|
||||
assert!(proof_empty.leaf().is_empty());
|
||||
|
||||
let mut partial = PartialSmt::from_proofs([proof0, proof2, proof_empty]).unwrap();
|
||||
|
||||
assert_eq!(full.root(), partial.root());
|
||||
assert_eq!(partial.get_value(&key0).unwrap(), value0);
|
||||
let error = partial.get_value(&key1).unwrap_err();
|
||||
assert_matches!(error, MerkleError::UntrackedKey(_));
|
||||
assert_eq!(partial.get_value(&key2).unwrap(), value2);
|
||||
|
||||
// Insert new values for added keys with empty and non-empty values.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
let new_value0 = Word::from(rand_array());
|
||||
let new_value2 = Word::from(rand_array());
|
||||
// A non-empty value for the key that was previously empty.
|
||||
let new_value_empty_key = Word::from(rand_array());
|
||||
|
||||
full.insert(key0, new_value0);
|
||||
full.insert(key2, new_value2);
|
||||
full.insert(key_empty, new_value_empty_key);
|
||||
|
||||
partial.insert(key0, new_value0).unwrap();
|
||||
partial.insert(key2, new_value2).unwrap();
|
||||
// This updates a key whose value was previously empty.
|
||||
partial.insert(key_empty, new_value_empty_key).unwrap();
|
||||
|
||||
assert_eq!(full.root(), partial.root());
|
||||
assert_eq!(partial.get_value(&key0).unwrap(), new_value0);
|
||||
assert_eq!(partial.get_value(&key2).unwrap(), new_value2);
|
||||
assert_eq!(partial.get_value(&key_empty).unwrap(), new_value_empty_key);
|
||||
|
||||
// Remove an added key.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
full.insert(key0, EMPTY_WORD);
|
||||
partial.insert(key0, EMPTY_WORD).unwrap();
|
||||
|
||||
assert_eq!(full.root(), partial.root());
|
||||
assert_eq!(partial.get_value(&key0).unwrap(), EMPTY_WORD);
|
||||
|
||||
// Check if returned openings are the same in partial and full SMT.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
// This is a key whose value is empty since it was removed.
|
||||
assert_eq!(full.open(&key0), partial.open(&key0).unwrap());
|
||||
// This is a key whose value is non-empty.
|
||||
assert_eq!(full.open(&key2), partial.open(&key2).unwrap());
|
||||
|
||||
// Attempting to update a key whose merkle path was not added is an error.
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
let error = partial.clone().insert(key1, Word::from(rand_array())).unwrap_err();
|
||||
assert_matches!(error, MerkleError::UntrackedKey(_));
|
||||
|
||||
let error = partial.insert(key1, EMPTY_WORD).unwrap_err();
|
||||
assert_matches!(error, MerkleError::UntrackedKey(_));
|
||||
}
|
||||
|
||||
/// Tests that an `SmtLeaf::Multiple` variant can be added to a partial SMT.
|
||||
#[test]
|
||||
fn partial_smt_multiple_leaf_success() {
|
||||
// key0 and key1 have the same felt at index 3 so they will be placed in the same leaf.
|
||||
let key0 = RpoDigest::from(Word::from([ZERO, ZERO, ZERO, ONE]));
|
||||
let key1 = RpoDigest::from(Word::from([ONE, ONE, ONE, ONE]));
|
||||
let key2 = RpoDigest::from(Word::from(rand_array()));
|
||||
|
||||
let value0 = Word::from(rand_array());
|
||||
let value1 = Word::from(rand_array());
|
||||
let value2 = Word::from(rand_array());
|
||||
|
||||
let full = Smt::with_entries([(key0, value0), (key1, value1), (key2, value2)]).unwrap();
|
||||
|
||||
// Make sure our assumption about the leaf being a multiple is correct.
|
||||
let SmtLeaf::Multiple(_) = full.get_leaf(&key0) else {
|
||||
panic!("expected full tree to produce multiple leaf")
|
||||
};
|
||||
|
||||
let proof0 = full.open(&key0);
|
||||
let proof2 = full.open(&key2);
|
||||
|
||||
let partial = PartialSmt::from_proofs([proof0, proof2]).unwrap();
|
||||
|
||||
assert_eq!(partial.root(), full.root());
|
||||
|
||||
assert_eq!(partial.get_leaf(&key0).unwrap(), full.get_leaf(&key0));
|
||||
// key1 is present in the partial tree because it is part of the proof of key0.
|
||||
assert_eq!(partial.get_leaf(&key1).unwrap(), full.get_leaf(&key1));
|
||||
assert_eq!(partial.get_leaf(&key2).unwrap(), full.get_leaf(&key2));
|
||||
}
|
||||
}
|
425
src/merkle/smt/simple/mod.rs
Normal file
|
@ -0,0 +1,425 @@
|
|||
use alloc::{collections::BTreeSet, vec::Vec};
|
||||
|
||||
use super::{
|
||||
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
|
||||
MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
|
||||
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// SPARSE MERKLE TREE
|
||||
// ================================================================================================
|
||||
|
||||
type Leaves = super::Leaves<Word>;
|
||||
|
||||
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
|
||||
///
|
||||
/// The root of the tree is recomputed on each new leaf update.
|
||||
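/// # Example
///
/// A minimal sketch (`value` stands for any `Word`; the depth of 8 is chosen arbitrarily):
///
/// ```ignore
/// let mut smt: SimpleSmt<8> = SimpleSmt::new().unwrap();
/// let key = LeafIndex::<8>::new(3).unwrap();
/// // Any key that was never written reads back as the empty value.
/// assert_eq!(smt.get_leaf(&key), SimpleSmt::<8>::EMPTY_VALUE);
/// let old = smt.insert(key, value);
/// assert_eq!(old, SimpleSmt::<8>::EMPTY_VALUE);
/// ```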
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct SimpleSmt<const DEPTH: u8> {
|
||||
root: RpoDigest,
|
||||
inner_nodes: InnerNodes,
|
||||
leaves: Leaves,
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
|
||||
// CONSTANTS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// The default value used to compute the hash of empty leaves
|
||||
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;
|
||||
|
||||
// CONSTRUCTORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns a new [SimpleSmt].
|
||||
///
|
||||
/// All leaves in the returned tree are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if DEPTH is 0 or is greater than 64.
|
||||
pub fn new() -> Result<Self, MerkleError> {
|
||||
// validate the range of the depth.
|
||||
if DEPTH < SMT_MIN_DEPTH {
|
||||
return Err(MerkleError::DepthTooSmall(DEPTH));
|
||||
} else if SMT_MAX_DEPTH < DEPTH {
|
||||
return Err(MerkleError::DepthTooBig(DEPTH as u64));
|
||||
}
|
||||
|
||||
let root = *EmptySubtreeRoots::entry(DEPTH, 0);
|
||||
|
||||
Ok(Self {
|
||||
root,
|
||||
inner_nodes: Default::default(),
|
||||
leaves: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
|
||||
///
|
||||
/// All leaves omitted from the entries list are set to [ZERO; 4].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if:
|
||||
/// - The depth is 0 or is greater than 64.
|
||||
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
|
||||
/// - The provided entries contain multiple values for the same key.
|
||||
pub fn with_leaves(
|
||||
entries: impl IntoIterator<Item = (u64, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
// create an empty tree
|
||||
let mut tree = Self::new()?;
|
||||
|
||||
// compute the max number of entries. We use an upper bound of depth 63 because we consider
|
||||
// passing in a vector of size 2^64 infeasible.
|
||||
let max_num_entries = 2_usize.pow(DEPTH.min(63).into());
|
||||
|
||||
// Because this is a sparse data structure, EMPTY_WORD values are not stored in the
|
||||
// `BTreeMap`, so entries with the empty value need additional tracking.
|
||||
let mut key_set_to_zero = BTreeSet::new();
|
||||
|
||||
for (idx, (key, value)) in entries.into_iter().enumerate() {
|
||||
if idx >= max_num_entries {
|
||||
return Err(MerkleError::TooManyEntries(max_num_entries));
|
||||
}
|
||||
|
||||
let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
|
||||
|
||||
if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
|
||||
return Err(MerkleError::DuplicateValuesForIndex(key));
|
||||
}
|
||||
|
||||
if value == Self::EMPTY_VALUE {
|
||||
key_set_to_zero.insert(key);
|
||||
};
|
||||
}
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
|
||||
///
|
||||
/// This function performs minimal consistency checking. It is the caller's responsibility to
|
||||
/// ensure the passed arguments are correct and consistent with each other.
|
||||
///
|
||||
/// # Panics
|
||||
/// With debug assertions on, this function panics if `root` does not match the root node in
|
||||
/// `inner_nodes`.
|
||||
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
|
||||
// Our particular implementation of `from_raw_parts()` never returns `Err`.
|
||||
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
|
||||
}
|
||||
|
||||
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
|
||||
/// starting at index 0.
|
||||
pub fn with_contiguous_leaves(
|
||||
entries: impl IntoIterator<Item = Word>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
Self::with_leaves(
|
||||
entries
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
|
||||
)
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns the depth of the tree
|
||||
pub const fn depth(&self) -> u8 {
|
||||
DEPTH
|
||||
}
|
||||
|
||||
/// Returns the root of the tree
|
||||
pub fn root(&self) -> RpoDigest {
|
||||
<Self as SparseMerkleTree<DEPTH>>::root(self)
|
||||
}
|
||||
|
||||
/// Returns the number of non-empty leaves in this tree.
|
||||
pub fn num_leaves(&self) -> usize {
|
||||
self.leaves.len()
|
||||
}
|
||||
|
||||
/// Returns the leaf at the specified index.
|
||||
pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
|
||||
<Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
|
||||
}
|
||||
|
||||
/// Returns a node at the specified index.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
|
||||
/// the depth of this Merkle tree.
|
||||
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
|
||||
if index.is_root() {
|
||||
Err(MerkleError::DepthTooSmall(index.depth()))
|
||||
} else if index.depth() > DEPTH {
|
||||
Err(MerkleError::DepthTooBig(index.depth() as u64))
|
||||
} else if index.depth() == DEPTH {
|
||||
let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
|
||||
|
||||
Ok(leaf.into())
|
||||
} else {
|
||||
Ok(self.get_inner_node(index).hash())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||
/// path to the leaf, as well as the leaf itself.
|
||||
pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
|
||||
<Self as SparseMerkleTree<DEPTH>>::open(self, key)
|
||||
}
|
||||
|
||||
/// Returns a boolean value indicating whether the SMT is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
|
||||
self.root == Self::EMPTY_ROOT
|
||||
}
|
||||
|
||||
// ITERATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Returns an iterator over the leaves of this [SimpleSmt].
|
||||
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
|
||||
self.leaves.iter().map(|(i, w)| (*i, w))
|
||||
}
|
||||
|
||||
/// Returns an iterator over the inner nodes of this [SimpleSmt].
|
||||
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
|
||||
self.inner_nodes.values().map(|e| InnerNodeInfo {
|
||||
value: e.hash(),
|
||||
left: e.left,
|
||||
right: e.right,
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Inserts a value at the specified key, returning the previous value associated with that key.
|
||||
/// Recall that by definition, any key that hasn't been updated is associated with
|
||||
/// [`EMPTY_WORD`].
|
||||
///
|
||||
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
|
||||
/// updating the root itself.
|
||||
pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
|
||||
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
|
||||
}
|
||||
|
||||
/// Computes what changes are necessary to insert the specified key-value pairs into this
|
||||
/// Merkle tree, allowing for validation before applying those changes.
|
||||
///
|
||||
/// This method returns a [`MutationSet`], which contains all the information for inserting
|
||||
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
|
||||
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
|
||||
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
|
||||
/// Merkle tree, or [`drop()`] to discard them.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use miden_crypto::Word;
|
||||
/// # use miden_crypto::merkle::{EmptySubtreeRoots, LeafIndex, SimpleSmt};
|
||||
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
|
||||
/// let pair = (LeafIndex::default(), Word::default());
|
||||
/// let mutations = smt.compute_mutations(vec![pair]);
|
||||
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
|
||||
/// smt.apply_mutations(mutations);
|
||||
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
|
||||
/// ```
|
||||
pub fn compute_mutations(
|
||||
&self,
|
||||
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
|
||||
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
|
||||
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
|
||||
}
|
||||
|
||||
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
|
||||
/// tree.
|
||||
///
|
||||
/// # Errors
|
||||
/// If `mutations` was computed on a tree with a different root than this one, returns
|
||||
/// [`MerkleError::ConflictingRoots`], where the expected root is the root hash the `mutations`
|
||||
/// were computed against and the actual root is the current root of this tree.
|
||||
pub fn apply_mutations(
|
||||
&mut self,
|
||||
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
|
||||
) -> Result<(), MerkleError> {
|
||||
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
|
||||
}
|
||||
|
||||
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
|
||||
/// this tree and returns the reverse mutation set.
|
||||
///
|
||||
/// Applying the reverse mutation sets to the updated tree will revert the changes.
|
||||
///
|
||||
/// # Errors
|
||||
/// If `mutations` was computed on a tree with a different root than this one, returns
|
||||
/// [`MerkleError::ConflictingRoots`], where the expected root is the root hash the `mutations`
|
||||
/// were computed against and the actual root is the current root of this tree.
|
||||
pub fn apply_mutations_with_reversion(
|
||||
&mut self,
|
||||
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
|
||||
) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
|
||||
<Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
|
||||
}
|
||||
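// Usage sketch (hypothetical `smt` and `mutations`): applying the returned reverse set rolls
// the tree back to its previous root.
//
//     let old_root = smt.root();
//     let reverse = smt.apply_mutations_with_reversion(mutations)?;
//     smt.apply_mutations(reverse)?;
//     assert_eq!(smt.root(), old_root);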
|
||||
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
|
||||
/// computed as `DEPTH - SUBTREE_DEPTH`.
|
||||
///
|
||||
/// Returns the new root.
|
||||
pub fn set_subtree<const SUBTREE_DEPTH: u8>(
|
||||
&mut self,
|
||||
subtree_insertion_index: u64,
|
||||
subtree: SimpleSmt<SUBTREE_DEPTH>,
|
||||
) -> Result<RpoDigest, MerkleError> {
|
||||
if SUBTREE_DEPTH > DEPTH {
|
||||
return Err(MerkleError::SubtreeDepthExceedsDepth {
|
||||
subtree_depth: SUBTREE_DEPTH,
|
||||
tree_depth: DEPTH,
|
||||
});
|
||||
}
|
||||
|
||||
// Verify that `subtree_insertion_index` is valid.
|
||||
let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
|
||||
let subtree_root_index =
|
||||
NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
|
||||
|
||||
// add leaves
|
||||
// --------------
|
||||
|
||||
// The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
|
||||
// insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
|
||||
// valid as they are. However, consider what happens when we insert at
|
||||
// `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
|
||||
// you can see it as there's a full subtree sitting on its left. In general, for
|
||||
// `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
|
||||
// to insert, so we need to adjust all its leaves by `i * 2^d`.
|
||||
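// For example, with `SUBTREE_DEPTH = 2` and `subtree_insertion_index = 1`, subtree leaf 0
// lands at global leaf index 1 * 2^2 + 0 = 4, and subtree leaf 3 lands at index 7.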
let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
|
||||
for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
|
||||
let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
|
||||
debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));
|
||||
|
||||
self.leaves.insert(new_leaf_idx, *leaf_value);
|
||||
}
|
||||
|
||||
// add subtree's branch nodes (which includes the root)
|
||||
// --------------
|
||||
for (branch_idx, branch_node) in subtree.inner_nodes {
|
||||
let new_branch_idx = {
|
||||
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
|
||||
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
|
||||
+ branch_idx.value();
|
||||
|
||||
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
|
||||
};
|
||||
|
||||
self.inner_nodes.insert(new_branch_idx, branch_node);
|
||||
}
|
||||
|
||||
// recompute nodes starting from subtree root
|
||||
// --------------
|
||||
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
|
||||
|
||||
Ok(self.root)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
|
||||
type Key = LeafIndex<DEPTH>;
|
||||
type Value = Word;
|
||||
type Leaf = Word;
|
||||
type Opening = ValuePath;
|
||||
|
||||
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
|
||||
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
|
||||
|
||||
fn from_raw_parts(
|
||||
inner_nodes: InnerNodes,
|
||||
leaves: Leaves,
|
||||
root: RpoDigest,
|
||||
) -> Result<Self, MerkleError> {
|
||||
if cfg!(debug_assertions) {
|
||||
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
|
||||
assert_eq!(root_node.hash(), root);
|
||||
}
|
||||
|
||||
Ok(Self { root, inner_nodes, leaves })
|
||||
}
|
||||
|
||||
fn root(&self) -> RpoDigest {
|
||||
self.root
|
||||
}
|
||||
|
||||
fn set_root(&mut self, root: RpoDigest) {
|
||||
self.root = root;
|
||||
}
|
||||
|
||||
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
|
||||
self.inner_nodes
|
||||
.get(&index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
|
||||
}
|
||||
|
||||
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
|
||||
self.inner_nodes.insert(index, inner_node)
|
||||
}
|
||||
|
||||
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
|
||||
self.inner_nodes.remove(&index)
|
||||
}
|
||||
|
||||
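// Note: storing `EMPTY_VALUE` is represented by removing the map entry, which keeps the
// leaves map sparse and makes `num_leaves()` count only non-empty leaves.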
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
|
||||
if value == Self::EMPTY_VALUE {
|
||||
self.leaves.remove(&key.value())
|
||||
} else {
|
||||
self.leaves.insert(key.value(), value)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
|
||||
self.get_leaf(key)
|
||||
}
|
||||
|
||||
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
|
||||
let leaf_pos = key.value();
|
||||
match self.leaves.get(&leaf_pos) {
|
||||
Some(word) => *word,
|
||||
None => Self::EMPTY_VALUE,
|
||||
}
|
||||
}
|
||||
|
||||
fn hash_leaf(leaf: &Word) -> RpoDigest {
|
||||
// `SimpleSmt` takes the leaf value itself as the hash
|
||||
leaf.into()
|
||||
}
|
||||
|
||||
fn construct_prospective_leaf(
|
||||
&self,
|
||||
_existing_leaf: Word,
|
||||
_key: &LeafIndex<DEPTH>,
|
||||
value: &Word,
|
||||
) -> Word {
|
||||
*value
|
||||
}
|
||||
|
||||
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
|
||||
*key
|
||||
}
|
||||
|
||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
|
||||
(path, leaf).into()
|
||||
}
|
||||
|
||||
fn pairs_to_leaf(mut pairs: Vec<(LeafIndex<DEPTH>, Word)>) -> Word {
|
||||
// SimpleSmt can't have more than one value per key.
|
||||
assert_eq!(pairs.len(), 1);
|
||||
let (_key, value) = pairs.pop().unwrap();
|
||||
value
|
||||
}
|
||||
}
|
478
src/merkle/smt/simple/tests.rs
Normal file
|
@ -0,0 +1,478 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
|
||||
use super::{
|
||||
super::{MerkleError, RpoDigest, SimpleSmt},
|
||||
NodeIndex,
|
||||
};
|
||||
use crate::{
|
||||
hash::rpo::Rpo256,
|
||||
merkle::{
|
||||
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
|
||||
InnerNodeInfo, LeafIndex, MerkleTree,
|
||||
},
|
||||
Word, EMPTY_WORD,
|
||||
};
|
||||
|
||||
// TEST DATA
|
||||
// ================================================================================================
|
||||
|
||||
const KEYS4: [u64; 4] = [0, 1, 2, 3];
|
||||
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
|
||||
|
||||
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
|
||||
|
||||
const VALUES8: [RpoDigest; 8] = [
|
||||
int_to_node(1),
|
||||
int_to_node(2),
|
||||
int_to_node(3),
|
||||
int_to_node(4),
|
||||
int_to_node(5),
|
||||
int_to_node(6),
|
||||
int_to_node(7),
|
||||
int_to_node(8),
|
||||
];
|
||||
|
||||
const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];
|
||||
|
||||
// TESTS
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn build_empty_tree() {
|
||||
// tree of depth 3
|
||||
let smt = SimpleSmt::<3>::new().unwrap();
|
||||
let mt = MerkleTree::new(ZERO_VALUES8).unwrap();
|
||||
assert_eq!(mt.root(), smt.root());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_sparse_tree() {
|
||||
const DEPTH: u8 = 3;
|
||||
let mut smt = SimpleSmt::<DEPTH>::new().unwrap();
|
||||
let mut values = ZERO_VALUES8.to_vec();
|
||||
|
||||
assert_eq!(smt.num_leaves(), 0);
|
||||
|
||||
// insert single value
|
||||
let key = 6;
|
||||
let new_node = int_to_leaf(7);
|
||||
values[key as usize] = new_node;
|
||||
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
|
||||
let mt2 = MerkleTree::new(values.clone()).unwrap();
|
||||
assert_eq!(mt2.root(), smt.root());
|
||||
assert_eq!(
|
||||
mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
|
||||
smt.open(&LeafIndex::<3>::new(6).unwrap()).path
|
||||
);
|
||||
assert_eq!(old_value, EMPTY_WORD);
|
||||
assert_eq!(smt.num_leaves(), 1);
|
||||
|
||||
// insert second value at distinct leaf branch
|
||||
let key = 2;
|
||||
let new_node = int_to_leaf(3);
|
||||
values[key as usize] = new_node;
|
||||
let old_value = smt.insert(LeafIndex::<DEPTH>::new(key).unwrap(), new_node);
|
||||
let mt3 = MerkleTree::new(values).unwrap();
|
||||
assert_eq!(mt3.root(), smt.root());
|
||||
assert_eq!(
|
||||
mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
|
||||
smt.open(&LeafIndex::<3>::new(2).unwrap()).path
|
||||
);
|
||||
assert_eq!(old_value, EMPTY_WORD);
|
||||
assert_eq!(smt.num_leaves(), 2);
|
||||
}
|
||||
|
||||
/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected
|
||||
#[test]
|
||||
fn build_contiguous_tree() {
|
||||
let tree_with_leaves =
|
||||
SimpleSmt::<2>::with_leaves([0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4)))
|
||||
.unwrap();
|
||||
|
||||
let tree_with_contiguous_leaves =
|
||||
SimpleSmt::<2>::with_contiguous_leaves(digests_to_words(&VALUES4)).unwrap();
|
||||
|
||||
assert_eq!(tree_with_leaves, tree_with_contiguous_leaves);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_depth2_tree() {
|
||||
let tree =
|
||||
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
|
||||
|
||||
// check internal structure
|
||||
let (root, node2, node3) = compute_internal_nodes();
|
||||
assert_eq!(root, tree.root());
|
||||
assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
|
||||
assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
|
||||
|
||||
// check get_node()
|
||||
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
|
||||
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
|
||||
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
|
||||
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
|
||||
|
||||
// check get_path(): depth 2
|
||||
assert_eq!(vec![VALUES4[1], node3], *tree.open(&LeafIndex::<2>::new(0).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[0], node3], *tree.open(&LeafIndex::<2>::new(1).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[3], node2], *tree.open(&LeafIndex::<2>::new(2).unwrap()).path);
|
||||
assert_eq!(vec![VALUES4[2], node2], *tree.open(&LeafIndex::<2>::new(3).unwrap()).path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inner_node_iterator() -> Result<(), MerkleError> {
|
||||
let tree =
|
||||
SimpleSmt::<2>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
|
||||
|
||||
// check depth 2
|
||||
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
|
||||
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
|
||||
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
|
||||
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
|
||||
|
||||
// get parent nodes
|
||||
let root = tree.root();
|
||||
let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
|
||||
let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
|
||||
let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
|
||||
let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
|
||||
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
|
||||
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
|
||||
|
||||
let mut nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
|
||||
let mut expected = [
|
||||
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
|
||||
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
|
||||
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
|
||||
];
|
||||
nodes.sort();
|
||||
expected.sort();
|
||||
|
||||
assert_eq!(nodes, expected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert() {
|
||||
const DEPTH: u8 = 3;
|
||||
let mut tree =
|
||||
SimpleSmt::<DEPTH>::with_leaves(KEYS8.into_iter().zip(digests_to_words(&VALUES8))).unwrap();
|
||||
assert_eq!(tree.num_leaves(), 8);
|
||||
|
||||
// update one value
|
||||
let key = 3;
|
||||
let new_node = int_to_leaf(9);
|
||||
let mut expected_values = digests_to_words(&VALUES8);
|
||||
expected_values[key] = new_node;
|
||||
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
|
||||
|
||||
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
|
||||
assert_eq!(expected_tree.root(), tree.root);
|
||||
assert_eq!(old_leaf, *VALUES8[key]);
|
||||
assert_eq!(tree.num_leaves(), 8);
|
||||
|
||||
// update another value
|
||||
let key = 6;
|
||||
let new_node = int_to_leaf(10);
|
||||
expected_values[key] = new_node;
|
||||
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
|
||||
|
||||
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
|
||||
assert_eq!(expected_tree.root(), tree.root);
|
||||
assert_eq!(old_leaf, *VALUES8[key]);
|
||||
assert_eq!(tree.num_leaves(), 8);
|
||||
|
||||
// set a leaf to empty value
|
||||
let key = 5;
|
||||
let new_node = EMPTY_WORD;
|
||||
expected_values[key] = new_node;
|
||||
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
|
||||
|
||||
let old_leaf = tree.insert(LeafIndex::<DEPTH>::new(key as u64).unwrap(), new_node);
|
||||
assert_eq!(expected_tree.root(), tree.root);
|
||||
assert_eq!(old_leaf, *VALUES8[key]);
|
||||
assert_eq!(tree.num_leaves(), 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn small_tree_opening_is_consistent() {
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
let e = Rpo256::merge(&[a.into(), b.into()]);
|
||||
let f = Rpo256::merge(&[z.into(), z.into()]);
|
||||
let g = Rpo256::merge(&[c.into(), z.into()]);
|
||||
let h = Rpo256::merge(&[z.into(), d.into()]);
|
||||
|
||||
let i = Rpo256::merge(&[e, f]);
|
||||
let j = Rpo256::merge(&[g, h]);
|
||||
|
||||
let k = Rpo256::merge(&[i, j]);
|
||||
|
||||
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
|
||||
let tree = SimpleSmt::<3>::with_leaves(entries).unwrap();
|
||||
|
||||
assert_eq!(tree.root(), k);
|
||||
|
||||
let cases: Vec<(u64, Vec<RpoDigest>)> = vec![
|
||||
(0, vec![b.into(), f, j]),
|
||||
(1, vec![a.into(), f, j]),
|
||||
(4, vec![z.into(), h, i]),
|
||||
(7, vec![z.into(), g, i]),
|
||||
];
|
||||
|
||||
for (key, path) in cases {
|
||||
let opening = tree.open(&LeafIndex::<3>::new(key).unwrap());
|
||||
|
||||
assert_eq!(path, *opening.path);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_fail_on_duplicates() {
|
||||
let values = [
|
||||
// same key, same value
|
||||
(int_to_leaf(1), int_to_leaf(1)),
|
||||
// same key, different values
|
||||
(int_to_leaf(1), int_to_leaf(2)),
|
||||
// same key, set to zero
|
||||
(EMPTY_WORD, int_to_leaf(1)),
|
||||
// same key, re-set to zero
|
||||
(int_to_leaf(1), EMPTY_WORD),
|
||||
// same key, set to zero twice
|
||||
(EMPTY_WORD, EMPTY_WORD),
|
||||
];
|
||||
|
||||
for (first, second) in values.iter() {
|
||||
// consecutive
|
||||
let entries = [(1, *first), (1, *second)];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
|
||||
// not consecutive
|
||||
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn with_no_duplicates_empty_node() {
|
||||
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
|
||||
let smt = SimpleSmt::<64>::with_leaves(entries);
|
||||
assert!(smt.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_with_leaves_nonexisting_leaf() {
|
||||
// TESTING WITH EMPTY WORD
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<1>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<2>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, EMPTY_WORD)];
|
||||
let result = SimpleSmt::<3>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// TESTING WITH A VALUE
|
||||
// --------------------------------------------------------------------------------------------
|
||||
let value = int_to_node(1);
|
||||
|
||||
// Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist.
|
||||
let leaves = [(2, *value)];
|
||||
let result = SimpleSmt::<1>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist.
|
||||
let leaves = [(4, *value)];
|
||||
let result = SimpleSmt::<2>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist.
|
||||
let leaves = [(8, *value)];
|
||||
let result = SimpleSmt::<3>::with_leaves(leaves);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree() {
|
||||
// Final Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
let e = Rpo256::merge(&[a.into(), b.into()]);
|
||||
let f = Rpo256::merge(&[z.into(), z.into()]);
|
||||
let g = Rpo256::merge(&[c.into(), z.into()]);
|
||||
let h = Rpo256::merge(&[z.into(), d.into()]);
|
||||
|
||||
let i = Rpo256::merge(&[e, f]);
|
||||
let j = Rpo256::merge(&[g, h]);
|
||||
|
||||
let k = Rpo256::merge(&[i, j]);
|
||||
|
||||
// subtree:
|
||||
// g
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::<1>::with_leaves(entries).unwrap()
|
||||
};
|
||||
|
||||
// insert subtree
|
||||
const TREE_DEPTH: u8 = 3;
|
||||
let tree = {
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
let mut tree = SimpleSmt::<TREE_DEPTH>::with_leaves(entries).unwrap();
|
||||
|
||||
tree.set_subtree(2, subtree).unwrap();
|
||||
|
||||
tree
|
||||
};
|
||||
|
||||
assert_eq!(tree.root(), k);
|
||||
assert_eq!(tree.get_leaf(&LeafIndex::<TREE_DEPTH>::new(4).unwrap()), c);
|
||||
assert_eq!(tree.get_inner_node(NodeIndex::new_unchecked(2, 2)).hash(), g);
|
||||
}
|
||||
|
||||
/// Ensures that an invalid subtree insertion index passed to `set_subtree()` leaves the tree unchanged.
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree_unchanged_for_wrong_index() {
|
||||
// Final Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
// subtree:
|
||||
// g
|
||||
// / \
|
||||
// c 0
|
||||
let subtree = {
|
||||
let entries = vec![(0, c)];
|
||||
SimpleSmt::<1>::with_leaves(entries).unwrap()
|
||||
};
|
||||
|
||||
let mut tree = {
|
||||
let entries = vec![(0, a), (1, b), (7, d)];
|
||||
SimpleSmt::<3>::with_leaves(entries).unwrap()
|
||||
};
|
||||
let tree_root_before_insertion = tree.root();
|
||||
|
||||
// insert subtree
|
||||
assert!(tree.set_subtree(500, subtree).is_err());
|
||||
|
||||
assert_eq!(tree.root(), tree_root_before_insertion);
|
||||
}
|
||||
|
||||
/// We insert an empty subtree that has the same depth as the original tree
|
||||
#[test]
|
||||
fn test_simplesmt_set_subtree_entire_tree() {
|
||||
// Initial Tree:
|
||||
//
|
||||
// ____k____
|
||||
// / \
|
||||
// _i_ _j_
|
||||
// / \ / \
|
||||
// e f g h
|
||||
// / \ / \ / \ / \
|
||||
// a b 0 0 c 0 0 d
|
||||
|
||||
let z = EMPTY_WORD;
|
||||
|
||||
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
|
||||
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
|
||||
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
|
||||
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
|
||||
|
||||
// subtree: E3
|
||||
const DEPTH: u8 = 3;
|
||||
let subtree = { SimpleSmt::<DEPTH>::with_leaves(Vec::new()).unwrap() };
|
||||
assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
|
||||
|
||||
// insert subtree
|
||||
let mut tree = {
|
||||
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
|
||||
SimpleSmt::<3>::with_leaves(entries).unwrap()
|
||||
};
|
||||
|
||||
tree.set_subtree(0, subtree).unwrap();
|
||||
|
||||
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
|
||||
}
|
||||
|
||||
/// Tests that `EMPTY_ROOT` constant generated in the `SimpleSmt` equals to the root of the empty
|
||||
/// tree of depth 64
|
||||
#[test]
|
||||
fn test_simplesmt_check_empty_root_constant() {
|
||||
// get the root of the empty tree of depth 64
|
||||
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
|
||||
assert_eq!(empty_root_64_depth, SimpleSmt::<64>::EMPTY_ROOT);
|
||||
|
||||
// get the root of the empty tree of depth 32
|
||||
let empty_root_32_depth = EmptySubtreeRoots::empty_hashes(32)[0];
|
||||
assert_eq!(empty_root_32_depth, SimpleSmt::<32>::EMPTY_ROOT);
|
||||
|
||||
// get the root of the empty tree of depth 0
|
||||
let empty_root_1_depth = EmptySubtreeRoots::empty_hashes(1)[0];
|
||||
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
|
||||
let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
|
||||
let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
|
||||
let root = Rpo256::merge(&[node2, node3]);
|
||||
|
||||
(root, node2, node3)
|
||||
}
|
627
src/merkle/store/mod.rs
Normal file
@ -0,0 +1,627 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::borrow::Borrow;

use super::{
    mmr::Mmr, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTree, NodeIndex,
    PartialMerkleTree, RootPath, Rpo256, RpoDigest, SimpleSmt, Smt, ValuePath,
};
use crate::utils::{
    collections::{KvMap, RecordingMap},
    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};

#[cfg(test)]
mod tests;

// MERKLE STORE
// ================================================================================================

/// A default [MerkleStore] which uses a simple [BTreeMap] as the backing storage.
pub type DefaultMerkleStore = MerkleStore<BTreeMap<RpoDigest, StoreNode>>;

/// A [MerkleStore] with recording capabilities which uses [RecordingMap] as the backing storage.
pub type RecordingMerkleStore = MerkleStore<RecordingMap<RpoDigest, StoreNode>>;

#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct StoreNode {
    left: RpoDigest,
    right: RpoDigest,
}

/// An in-memory data store for Merkleized data.
///
/// This store allows the nodes of multiple trees to live as long as necessary and without
/// duplication, which makes it possible to implement space-efficient persistent data structures.
///
/// Example usage:
///
/// ```rust
/// # use miden_crypto::{ZERO, Felt, Word};
/// # use miden_crypto::merkle::{NodeIndex, MerkleStore, MerkleTree};
/// # use miden_crypto::hash::rpo::Rpo256;
/// # const fn int_to_node(value: u64) -> Word {
/// #     [Felt::new(value), ZERO, ZERO, ZERO]
/// # }
/// # let A = int_to_node(1);
/// # let B = int_to_node(2);
/// # let C = int_to_node(3);
/// # let D = int_to_node(4);
/// # let E = int_to_node(5);
/// # let F = int_to_node(6);
/// # let G = int_to_node(7);
/// # let H0 = int_to_node(8);
/// # let H1 = int_to_node(9);
/// # let T0 = MerkleTree::new([A, B, C, D, E, F, G, H0].to_vec()).expect("even number of leaves provided");
/// # let T1 = MerkleTree::new([A, B, C, D, E, F, G, H1].to_vec()).expect("even number of leaves provided");
/// # let ROOT0 = T0.root();
/// # let ROOT1 = T1.root();
/// let mut store: MerkleStore = MerkleStore::new();
///
/// // the store is initialized with the SMT empty nodes
/// assert_eq!(store.num_internal_nodes(), 255);
///
/// let tree1 = MerkleTree::new(vec![A, B, C, D, E, F, G, H0]).unwrap();
/// let tree2 = MerkleTree::new(vec![A, B, C, D, E, F, G, H1]).unwrap();
///
/// // populate the store with two merkle trees; common nodes are shared
/// store.extend(tree1.inner_nodes());
/// store.extend(tree2.inner_nodes());
///
/// // every leaf except the last is the same
/// for i in 0..7 {
///     let idx0 = NodeIndex::new(3, i).unwrap();
///     let d0 = store.get_node(ROOT0, idx0).unwrap();
///     let idx1 = NodeIndex::new(3, i).unwrap();
///     let d1 = store.get_node(ROOT1, idx1).unwrap();
///     assert_eq!(d0, d1, "Both trees have the same leaf at pos {i}");
/// }
///
/// // The leaves A-B-C-D are the same for both trees, and so are their 2 immediate parents
/// for i in 0..4 {
///     let idx0 = NodeIndex::new(3, i).unwrap();
///     let d0 = store.get_path(ROOT0, idx0).unwrap();
///     let idx1 = NodeIndex::new(3, i).unwrap();
///     let d1 = store.get_path(ROOT1, idx1).unwrap();
///     assert_eq!(d0.path[0..2], d1.path[0..2], "Both sub-trees are equal up to two levels");
/// }
///
/// // Common internal nodes are shared; the two added trees have a total of 30, but the store
/// // has only 10 new entries, corresponding to the 10 unique internal nodes of these trees.
/// assert_eq!(store.num_internal_nodes() - 255, 10);
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStore<T: KvMap<RpoDigest, StoreNode> = BTreeMap<RpoDigest, StoreNode>> {
    nodes: T,
}

impl<T: KvMap<RpoDigest, StoreNode>> Default for MerkleStore<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Creates an empty `MerkleStore` instance.
    pub fn new() -> MerkleStore<T> {
        // pre-populate the store with the empty hashes
        let nodes = empty_hashes().into_iter().collect();
        MerkleStore { nodes }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a count of the non-leaf nodes in the store.
    pub fn num_internal_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the node at `index` rooted on the tree `root`.
    ///
    /// # Errors
    /// This method can return the following errors:
    /// - `RootNotInStore` if the `root` is not present in the store.
    /// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
    ///   the store.
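    ///
    /// # Example
    ///
    /// A minimal sketch of fetching a leaf through the store (the 4-leaf tree and its values
    /// below are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex};
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves.clone()).unwrap();
    ///
    /// let mut store: MerkleStore = MerkleStore::new();
    /// store.extend(tree.inner_nodes());
    ///
    /// // the node at depth 2, position 0 is the first leaf
    /// let index = NodeIndex::new(2, 0).unwrap();
    /// assert_eq!(store.get_node(tree.root(), index).unwrap(), leaves[0].into());
    /// ```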
    pub fn get_node(&self, root: RpoDigest, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
        let mut hash = root;

        // corner case: check the root is in the store when called with index `NodeIndex::root()`
        self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;

        for i in (0..index.depth()).rev() {
            let node = self
                .nodes
                .get(&hash)
                .ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;

            let bit = (index.value() >> i) & 1;
            hash = if bit == 0 { node.left } else { node.right }
        }

        Ok(hash)
    }

    /// Returns the node at the specified `index` and its opening to the `root`.
    ///
    /// The path starts at the sibling of the target leaf.
    ///
    /// # Errors
    /// This method can return the following errors:
    /// - `RootNotInStore` if the `root` is not present in the store.
    /// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
    ///   the store.
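    ///
    /// # Example
    ///
    /// A minimal sketch showing that the returned opening verifies against the root (tree
    /// contents are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex};
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves).unwrap();
    /// let store: MerkleStore = MerkleStore::from(&tree);
    ///
    /// let index = NodeIndex::new(2, 1).unwrap();
    /// let opening = store.get_path(tree.root(), index).unwrap();
    /// assert!(opening.path.verify(index.value(), opening.value, &tree.root()).is_ok());
    /// ```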
    pub fn get_path(&self, root: RpoDigest, index: NodeIndex) -> Result<ValuePath, MerkleError> {
        let mut hash = root;
        let mut path = Vec::with_capacity(index.depth().into());

        // corner case: check the root is in the store when called with index `NodeIndex::root()`
        self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;

        for i in (0..index.depth()).rev() {
            let node = self
                .nodes
                .get(&hash)
                .ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;

            let bit = (index.value() >> i) & 1;
            hash = if bit == 0 {
                path.push(node.right);
                node.left
            } else {
                path.push(node.left);
                node.right
            }
        }

        // the path is computed from root to leaf, so it must be reversed
        path.reverse();

        Ok(ValuePath::new(hash, MerklePath::new(path)))
    }

    // LEAF TRAVERSAL
    // --------------------------------------------------------------------------------------------

    /// Returns the depth of the first leaf or an empty node encountered while traversing the tree
    /// from the specified root down according to the provided index.
    ///
    /// The `tree_depth` parameter specifies the depth of the tree rooted at `root`. The
    /// maximum value the argument accepts is [u64::BITS].
    ///
    /// # Errors
    /// Will return an error if:
    /// - The provided root is not found.
    /// - The provided `tree_depth` is greater than 64.
    /// - The provided `index` is not valid for a depth equivalent to `tree_depth`.
    /// - No leaf or an empty node was found while traversing the tree down to `tree_depth`.
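    ///
    /// For example, traversing an entirely empty tree hits an empty node at the root, so the
    /// returned depth is 0 (a sketch which assumes `EmptySubtreeRoots` is re-exported by the
    /// `merkle` module):
    ///
    /// ```rust
    /// # use miden_crypto::merkle::{EmptySubtreeRoots, MerkleStore};
    /// let store: MerkleStore = MerkleStore::new();
    /// let root = EmptySubtreeRoots::empty_hashes(64)[0];
    /// assert_eq!(store.get_leaf_depth(root, 64, 0).unwrap(), 0);
    /// ```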
    pub fn get_leaf_depth(
        &self,
        root: RpoDigest,
        tree_depth: u8,
        index: u64,
    ) -> Result<u8, MerkleError> {
        // validate depth and index
        if tree_depth > 64 {
            return Err(MerkleError::DepthTooBig(tree_depth as u64));
        }
        NodeIndex::new(tree_depth, index)?;

        // check if the root exists, providing the proper error report if it doesn't
        let empty = EmptySubtreeRoots::empty_hashes(tree_depth);
        let mut hash = root;
        if !self.nodes.contains_key(&hash) {
            return Err(MerkleError::RootNotInStore(hash));
        }

        // we traverse from root to leaf, so the path is reversed
        let mut path = (index << (64 - tree_depth)).reverse_bits();

        // iterate every depth and reconstruct the path from root to leaf
        for depth in 0..=tree_depth {
            // we short-circuit if an empty node has been found
            if hash == empty[depth as usize] {
                return Ok(depth);
            }

            // fetch the children pair, mapped by its parent hash
            let children = match self.nodes.get(&hash) {
                Some(node) => node,
                None => return Ok(depth),
            };

            // traverse down
            hash = if path & 1 == 0 { children.left } else { children.right };
            path >>= 1;
        }

        // return an error because we exhausted the index but didn't find either a leaf or an
        // empty node
        Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
    }

    /// Returns the index and value of a leaf node which is the only leaf node in a subtree
    /// defined by the provided root. If the subtree contains zero leaves or more than one leaf
    /// node, `None` is returned.
    ///
    /// The `tree_depth` parameter specifies the depth of the parent tree such that `root` is
    /// located in this tree at `root_index`. The maximum value the argument accepts is
    /// [u64::BITS].
    ///
    /// # Errors
    /// Will return an error if:
    /// - The provided root is not found.
    /// - The provided `tree_depth` is greater than 64.
    /// - The provided `root_index` has depth greater than `tree_depth`.
    /// - A lone node at depth `tree_depth` is not a leaf node.
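    ///
    /// For example, searching an entirely empty subtree yields `None`, since it contains no
    /// leaves at all (a sketch which assumes `EmptySubtreeRoots` is re-exported by the `merkle`
    /// module):
    ///
    /// ```rust
    /// # use miden_crypto::merkle::{EmptySubtreeRoots, MerkleStore, NodeIndex};
    /// let store: MerkleStore = MerkleStore::new();
    /// let root = EmptySubtreeRoots::empty_hashes(64)[0];
    /// assert_eq!(store.find_lone_leaf(root, NodeIndex::root(), 64).unwrap(), None);
    /// ```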
    pub fn find_lone_leaf(
        &self,
        root: RpoDigest,
        root_index: NodeIndex,
        tree_depth: u8,
    ) -> Result<Option<(NodeIndex, RpoDigest)>, MerkleError> {
        // we set max depth at u64::BITS as this is the largest meaningful value for a 64-bit index
        const MAX_DEPTH: u8 = u64::BITS as u8;
        if tree_depth > MAX_DEPTH {
            return Err(MerkleError::DepthTooBig(tree_depth as u64));
        }
        let empty = EmptySubtreeRoots::empty_hashes(MAX_DEPTH);

        let mut node = root;
        if !self.nodes.contains_key(&node) {
            return Err(MerkleError::RootNotInStore(node));
        }

        let mut index = root_index;
        if index.depth() > tree_depth {
            return Err(MerkleError::DepthTooBig(index.depth() as u64));
        }

        // traverse down following the path of single non-empty nodes; this works because if a
        // node has two empty children it cannot contain a lone leaf. Similarly, if a node has
        // two non-empty children it must contain at least two leaves.
        for depth in index.depth()..tree_depth {
            // if the node is a leaf, return; otherwise, examine the node's children
            let children = match self.nodes.get(&node) {
                Some(node) => node,
                None => return Ok(Some((index, node))),
            };

            let empty_node = empty[depth as usize + 1];
            node = if children.left != empty_node && children.right == empty_node {
                index = index.left_child();
                children.left
            } else if children.left == empty_node && children.right != empty_node {
                index = index.right_child();
                children.right
            } else {
                return Ok(None);
            };
        }

        // if we are here, we got to `tree_depth`; thus, either the current node is a leaf node,
        // and so we return it, or it is an internal node, and then we return an error
        if self.nodes.contains_key(&node) {
            Err(MerkleError::DepthTooBig(tree_depth as u64 + 1))
        } else {
            Ok(Some((index, node)))
        }
    }

    // DATA EXTRACTORS
    // --------------------------------------------------------------------------------------------

    /// Returns a subset of this Merkle store such that the returned Merkle store contains all
    /// nodes which are descendants of the specified roots.
    ///
    /// The roots for which no descendants exist in this Merkle store are ignored.
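    ///
    /// # Example
    ///
    /// A minimal sketch extracting the nodes of one tree out of a larger store (tree contents
    /// are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex};
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves.clone()).unwrap();
    ///
    /// let mut store: MerkleStore = MerkleStore::new();
    /// store.extend(tree.inner_nodes());
    ///
    /// // the subset rooted at the tree's root still resolves all of the tree's nodes
    /// let subset = store.subset([tree.root()].iter());
    /// let index = NodeIndex::new(2, 0).unwrap();
    /// assert_eq!(subset.get_node(tree.root(), index).unwrap(), leaves[0].into());
    /// ```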
    pub fn subset<I, R>(&self, roots: I) -> MerkleStore<T>
    where
        I: Iterator<Item = R>,
        R: Borrow<RpoDigest>,
    {
        let mut store = MerkleStore::new();
        for root in roots {
            let root = *root.borrow();
            store.clone_tree_from(root, self);
        }
        store
    }

    /// Returns an iterator over the inner nodes of the [MerkleStore].
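    ///
    /// For example, a freshly created store yields exactly its pre-populated empty-subtree
    /// nodes (a sketch):
    ///
    /// ```rust
    /// # use miden_crypto::merkle::MerkleStore;
    /// let store: MerkleStore = MerkleStore::new();
    /// assert_eq!(store.inner_nodes().count(), store.num_internal_nodes());
    /// ```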
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.nodes
            .iter()
            .map(|(r, n)| InnerNodeInfo { value: *r, left: n.left, right: n.right })
    }

    /// Returns an iterator over the non-empty leaves of the Merkle tree associated with the
    /// specified `root` and `max_depth`.
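    ///
    /// For example, a fully populated tree of depth 2 yields all four of its leaves (a sketch;
    /// leaf values are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree};
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves).unwrap();
    /// let store: MerkleStore = MerkleStore::from(&tree);
    /// assert_eq!(store.non_empty_leaves(tree.root(), 2).count(), 4);
    /// ```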
    pub fn non_empty_leaves(
        &self,
        root: RpoDigest,
        max_depth: u8,
    ) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
        let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth);
        let mut stack = Vec::new();
        stack.push((NodeIndex::new_unchecked(0, 0), root));

        core::iter::from_fn(move || {
            while let Some((index, node_hash)) = stack.pop() {
                // if we are at the max depth then we have reached a leaf
                if index.depth() == max_depth {
                    return Some((index, node_hash));
                }

                // fetch the node's children and push them onto the stack if they are not the
                // roots of empty subtrees
                if let Some(node) = self.nodes.get(&node_hash) {
                    if !empty_roots.contains(&node.left) {
                        stack.push((index.left_child(), node.left));
                    }
                    if !empty_roots.contains(&node.right) {
                        stack.push((index.right_child(), node.right));
                    }

                // if the node is not in the store assume it is a leaf
                } else {
                    return Some((index, node_hash));
                }
            }

            None
        })
    }

    // STATE MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Adds all the nodes of a Merkle path represented by `path`, opening to `node`. Returns the
    /// new root.
    ///
    /// This will compute the sibling elements determined by the Merkle `path` and `node`, and
    /// include all the nodes into the store.
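    ///
    /// # Example
    ///
    /// A minimal sketch recording a single opening produced by a Merkle tree; the recovered
    /// root must match the tree's root (tree contents are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex};
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves.clone()).unwrap();
    /// let path = tree.get_path(NodeIndex::new(2, 2).unwrap()).unwrap();
    ///
    /// let mut store: MerkleStore = MerkleStore::new();
    /// let root = store.add_merkle_path(2, leaves[2].into(), path).unwrap();
    /// assert_eq!(root, tree.root());
    /// ```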
    pub fn add_merkle_path(
        &mut self,
        index: u64,
        node: RpoDigest,
        path: MerklePath,
    ) -> Result<RpoDigest, MerkleError> {
        let root = path.inner_nodes(index, node)?.fold(RpoDigest::default(), |_, node| {
            let value: RpoDigest = node.value;
            let left: RpoDigest = node.left;
            let right: RpoDigest = node.right;

            debug_assert_eq!(Rpo256::merge(&[left, right]), value);
            self.nodes.insert(value, StoreNode { left, right });

            node.value
        });
        Ok(root)
    }

    /// Adds all the nodes of multiple Merkle paths into the store.
    ///
    /// This will compute the sibling elements for each Merkle `path` and include all the nodes
    /// into the store.
    ///
    /// For further reference, check [MerkleStore::add_merkle_path].
    pub fn add_merkle_paths<I>(&mut self, paths: I) -> Result<(), MerkleError>
    where
        I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
    {
        for (index_value, node, path) in paths.into_iter() {
            self.add_merkle_path(index_value, node, path)?;
        }
        Ok(())
    }

    /// Sets a node to `value`.
    ///
    /// # Errors
    /// This method can return the following errors:
    /// - `RootNotInStore` if the `root` is not present in the store.
    /// - `NodeNotInStore` if a node needed to traverse from `root` to `index` is not present in
    ///   the store.
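    ///
    /// # Example
    ///
    /// A minimal sketch replacing one leaf and reading it back under the new root (tree
    /// contents are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, Word, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, MerkleTree, NodeIndex};
    /// # use miden_crypto::hash::rpo::RpoDigest;
    /// let leaves: Vec<Word> = (1..=4u64).map(|v| [Felt::new(v), ZERO, ZERO, ZERO]).collect();
    /// let tree = MerkleTree::new(leaves).unwrap();
    /// let mut store: MerkleStore = MerkleStore::from(&tree);
    ///
    /// let index = NodeIndex::new(2, 3).unwrap();
    /// let new_leaf: RpoDigest = [Felt::new(42), ZERO, ZERO, ZERO].into();
    /// let new_root = store.set_node(tree.root(), index, new_leaf).unwrap().root;
    /// assert_eq!(store.get_node(new_root, index).unwrap(), new_leaf);
    /// ```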
    pub fn set_node(
        &mut self,
        mut root: RpoDigest,
        index: NodeIndex,
        value: RpoDigest,
    ) -> Result<RootPath, MerkleError> {
        let node = value;
        let ValuePath { value, path } = self.get_path(root, index)?;

        // perform the update only if the node value differs from the opening
        if node != value {
            root = self.add_merkle_path(index.value(), node, path.clone())?;
        }

        Ok(RootPath { root, path })
    }

    /// Merges two elements and adds the resulting node into the store.
    ///
    /// Merges arbitrary values. They may be leaves, nodes, or a mixture of both.
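    ///
    /// # Example
    ///
    /// A minimal sketch joining two digests under a fresh parent (the digests are illustrative):
    ///
    /// ```rust
    /// # use miden_crypto::{Felt, ZERO};
    /// # use miden_crypto::merkle::{MerkleStore, NodeIndex};
    /// # use miden_crypto::hash::rpo::RpoDigest;
    /// let left: RpoDigest = [Felt::new(1), ZERO, ZERO, ZERO].into();
    /// let right: RpoDigest = [Felt::new(2), ZERO, ZERO, ZERO].into();
    ///
    /// let mut store: MerkleStore = MerkleStore::new();
    /// let parent = store.merge_roots(left, right).unwrap();
    /// assert_eq!(store.get_node(parent, NodeIndex::new(1, 0).unwrap()).unwrap(), left);
    /// ```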
    pub fn merge_roots(
        &mut self,
        left_root: RpoDigest,
        right_root: RpoDigest,
    ) -> Result<RpoDigest, MerkleError> {
        let parent = Rpo256::merge(&[left_root, right_root]);
        self.nodes.insert(parent, StoreNode { left: left_root, right: right_root });

        Ok(parent)
    }

    // DESTRUCTURING
    // --------------------------------------------------------------------------------------------

    /// Returns the inner storage of this MerkleStore while consuming `self`.
    pub fn into_inner(self) -> T {
        self.nodes
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Recursively clones a tree with the specified root from the specified source into self.
    ///
    /// If the source store does not contain a tree with the specified root, this is a noop.
    fn clone_tree_from(&mut self, root: RpoDigest, source: &Self) {
        // process the node only if it is in the source
        if let Some(node) = source.nodes.get(&root) {
            // if the node has already been inserted, no need to process it further as all of its
            // descendants should be already cloned from the source store
            if self.nodes.insert(root, *node).is_none() {
                self.clone_tree_from(node.left, source);
                self.clone_tree_from(node.right, source);
            }
        }
    }
}

// CONVERSIONS
// ================================================================================================

impl<T: KvMap<RpoDigest, StoreNode>> From<&MerkleTree> for MerkleStore<T> {
    fn from(value: &MerkleTree) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>, const DEPTH: u8> From<&SimpleSmt<DEPTH>> for MerkleStore<T> {
    fn from(value: &SimpleSmt<DEPTH>) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&Smt> for MerkleStore<T> {
    fn from(value: &Smt) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&Mmr> for MerkleStore<T> {
    fn from(value: &Mmr) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<&PartialMerkleTree> for MerkleStore<T> {
    fn from(value: &PartialMerkleTree) -> Self {
        let nodes = combine_nodes_with_empty_hashes(value.inner_nodes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> From<T> for MerkleStore<T> {
    fn from(values: T) -> Self {
        let nodes = values.into_iter().chain(empty_hashes()).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<InnerNodeInfo> for MerkleStore<T> {
    fn from_iter<I: IntoIterator<Item = InnerNodeInfo>>(iter: I) -> Self {
        let nodes = combine_nodes_with_empty_hashes(iter).collect();
        Self { nodes }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for MerkleStore<T> {
    fn from_iter<I: IntoIterator<Item = (RpoDigest, StoreNode)>>(iter: I) -> Self {
        let nodes = iter.into_iter().chain(empty_hashes()).collect();
        Self { nodes }
    }
}

// ITERATORS
// ================================================================================================

impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
    fn extend<I: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: I) {
        self.nodes.extend(
            iter.into_iter()
                .map(|info| (info.value, StoreNode { left: info.left, right: info.right })),
        );
    }
}

// SERIALIZATION
// ================================================================================================

impl Serializable for StoreNode {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.left.write_into(target);
        self.right.write_into(target);
    }
}

impl Deserializable for StoreNode {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let left = RpoDigest::read_from(source)?;
        let right = RpoDigest::read_from(source)?;
        Ok(StoreNode { left, right })
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> Serializable for MerkleStore<T> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_u64(self.nodes.len() as u64);

        for (k, v) in self.nodes.iter() {
            k.write_into(target);
            v.write_into(target);
        }
    }
}

impl<T: KvMap<RpoDigest, StoreNode>> Deserializable for MerkleStore<T> {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let len = source.read_u64()?;
        let mut nodes: Vec<(RpoDigest, StoreNode)> = Vec::with_capacity(len as usize);

        for _ in 0..len {
            let key = RpoDigest::read_from(source)?;
            let value = StoreNode::read_from(source)?;
            nodes.push((key, value));
        }

        Ok(nodes.into_iter().collect())
    }
}

// HELPER FUNCTIONS
// ================================================================================================

/// Creates empty hashes for all the subtrees of a tree with a max depth of 255.
fn empty_hashes() -> impl IntoIterator<Item = (RpoDigest, StoreNode)> {
    let subtrees = EmptySubtreeRoots::empty_hashes(255);
    subtrees
        .iter()
        .rev()
        .copied()
        .zip(subtrees.iter().rev().skip(1).copied())
        .map(|(child, parent)| (parent, StoreNode { left: child, right: child }))
}

/// Consumes an iterator of [InnerNodeInfo] and returns an iterator of `(value, node)` tuples
/// which includes the nodes associated with the roots of empty subtrees up to a depth of 255.
fn combine_nodes_with_empty_hashes(
    nodes: impl IntoIterator<Item = InnerNodeInfo>,
) -> impl Iterator<Item = (RpoDigest, StoreNode)> {
    nodes
        .into_iter()
        .map(|info| (info.value, StoreNode { left: info.left, right: info.right }))
        .chain(empty_hashes())
}
933
src/merkle/store/tests.rs
Normal file
@ -0,0 +1,933 @@
use assert_matches::assert_matches;
use seq_macro::seq;
#[cfg(feature = "std")]
use {
    super::{Deserializable, Serializable},
    alloc::boxed::Box,
    std::error::Error,
};

use super::{
    DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
    PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest,
};
use crate::{
    merkle::{
        digests_to_words, int_to_leaf, int_to_node, LeafIndex, MerkleTree, SimpleSmt, SMT_MAX_DEPTH,
    },
    Felt, Word, ONE, WORD_SIZE, ZERO,
};

// TEST DATA
// ================================================================================================

const KEYS4: [u64; 4] = [0, 1, 2, 3];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
const VALUES8: [RpoDigest; 8] = [
    int_to_node(1),
    int_to_node(2),
    int_to_node(3),
    int_to_node(4),
    int_to_node(5),
    int_to_node(6),
    int_to_node(7),
    int_to_node(8),
];

// TESTS
// ================================================================================================

#[test]
fn test_root_not_in_store() -> Result<(), MerkleError> {
    let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
    let store = MerkleStore::from(&mtree);
    assert_matches!(
        store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
        Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
        "Leaf 0 is not a root"
    );
    assert_matches!(
        store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
        Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
        "Leaf 0 is not a root"
    );

    Ok(())
}
#[test]
fn test_merkle_tree() -> Result<(), MerkleError> {
    let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
    let store = MerkleStore::from(&mtree);

    // STORE LEAVES ARE CORRECT ===============================================================
    // checks that the leaves in the store correspond to the expected values
    assert_eq!(
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
        VALUES4[0],
        "node 0 must be in the tree"
    );
    assert_eq!(
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
        VALUES4[1],
        "node 1 must be in the tree"
    );
    assert_eq!(
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
        VALUES4[2],
        "node 2 must be in the tree"
    );
    assert_eq!(
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
        VALUES4[3],
        "node 3 must be in the tree"
    );

    // STORE LEAVES MATCH TREE ================================================================
    // sanity check the values returned by the store and the tree
    assert_eq!(
        mtree.get_node(NodeIndex::make(mtree.depth(), 0)).unwrap(),
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
        "node 0 must be the same for both MerkleTree and MerkleStore"
    );
    assert_eq!(
        mtree.get_node(NodeIndex::make(mtree.depth(), 1)).unwrap(),
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
        "node 1 must be the same for both MerkleTree and MerkleStore"
    );
    assert_eq!(
        mtree.get_node(NodeIndex::make(mtree.depth(), 2)).unwrap(),
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
        "node 2 must be the same for both MerkleTree and MerkleStore"
    );
    assert_eq!(
        mtree.get_node(NodeIndex::make(mtree.depth(), 3)).unwrap(),
        store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
        "node 3 must be the same for both MerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the tree
    let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap();
    assert_eq!(
        VALUES4[0], result.value,
        "Value for merkle path at index 0 must match leaf value"
    );
    assert_eq!(
        mtree.get_path(NodeIndex::make(mtree.depth(), 0)).unwrap(),
        result.path,
        "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap();
    assert_eq!(
        VALUES4[1], result.value,
        "Value for merkle path at index 1 must match leaf value"
    );
    assert_eq!(
        mtree.get_path(NodeIndex::make(mtree.depth(), 1)).unwrap(),
        result.path,
        "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap();
    assert_eq!(
        VALUES4[2], result.value,
        "Value for merkle path at index 2 must match leaf value"
    );
    assert_eq!(
        mtree.get_path(NodeIndex::make(mtree.depth(), 2)).unwrap(),
        result.path,
        "merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap();
    assert_eq!(
        VALUES4[3], result.value,
        "Value for merkle path at index 3 must match leaf value"
    );
    assert_eq!(
        mtree.get_path(NodeIndex::make(mtree.depth(), 3)).unwrap(),
        result.path,
        "merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
    );

    Ok(())
}
#[test]
fn test_empty_roots() {
    let store = MerkleStore::default();
    let mut root = RpoDigest::default();

    for depth in 0..255 {
        root = Rpo256::merge(&[root; 2]);
        assert!(
            store.get_node(root, NodeIndex::make(0, 0)).is_ok(),
            "The root of the empty tree of depth {depth} must be registered"
        );
    }
}

#[test]
fn test_leaf_paths_for_empty_trees() -> Result<(), MerkleError> {
    let store = MerkleStore::default();

    // Starts at 1 because leaves are not included in the store.
    // Ends at 64 because it is not possible to represent an index of a depth greater than 64,
    // because a u64 is used to index the leaf.
    seq!(DEPTH in 1_u8..64_u8 {
        let smt = SimpleSmt::<DEPTH>::new()?;

        let index = NodeIndex::make(DEPTH, 0);
        let store_path = store.get_path(smt.root(), index)?;
        let smt_path = smt.open(&LeafIndex::<DEPTH>::new(0)?).path;
        assert_eq!(
            store_path.value,
            RpoDigest::default(),
            "the leaf of an empty tree is always ZERO"
        );
        assert_eq!(
            store_path.path, smt_path,
            "the returned merkle path does not match the computed values"
        );
        assert_eq!(
            store_path.path.compute_root(DEPTH.into(), RpoDigest::default()).unwrap(),
            smt.root(),
            "computed root from the path must match the empty tree root"
        );
    });

    Ok(())
}
#[test]
fn test_get_invalid_node() {
    let mtree =
        MerkleTree::new(digests_to_words(&VALUES4)).expect("creating a merkle tree must work");
    let store = MerkleStore::from(&mtree);
    let _ = store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3));
}

#[test]
fn test_add_sparse_merkle_tree_one_level() -> Result<(), MerkleError> {
    let keys2: [u64; 2] = [0, 1];
    let leaves2: [Word; 2] = [int_to_leaf(1), int_to_leaf(2)];
    let smt = SimpleSmt::<1>::with_leaves(keys2.into_iter().zip(leaves2)).unwrap();
    let store = MerkleStore::from(&smt);

    let idx = NodeIndex::make(1, 0);
    assert_eq!(smt.get_node(idx).unwrap(), leaves2[0].into());
    assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());

    let idx = NodeIndex::make(1, 1);
    assert_eq!(smt.get_node(idx).unwrap(), leaves2[1].into());
    assert_eq!(store.get_node(smt.root(), idx).unwrap(), smt.get_node(idx).unwrap());

    Ok(())
}

#[test]
fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
    let smt =
        SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4)))
            .unwrap();

    let store = MerkleStore::from(&smt);

    // STORE LEAVES ARE CORRECT ==============================================================
    // checks that the leaves in the store correspond to the expected values
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
        VALUES4[0],
        "node 0 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
        VALUES4[1],
        "node 1 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
        VALUES4[2],
        "node 2 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
        VALUES4[3],
        "node 3 must be in the tree"
    );
    assert_eq!(
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
        RpoDigest::default(),
        "unmodified node 4 must be ZERO"
    );

    // STORE LEAVES MATCH TREE ===============================================================
    // sanity check the values returned by the store and the tree
    assert_eq!(
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
        "node 0 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
        "node 1 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
        "node 2 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
        "node 3 must be the same for both SparseMerkleTree and MerkleStore"
    );
    assert_eq!(
        smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
        store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
        "node 4 must be the same for both SparseMerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the tree
    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap();
    assert_eq!(
        VALUES4[0], result.value,
        "Value for merkle path at index 0 must match leaf value"
    );
    assert_eq!(
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(0).unwrap()).path,
        result.path,
        "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap();
    assert_eq!(
        VALUES4[1], result.value,
        "Value for merkle path at index 1 must match leaf value"
    );
    assert_eq!(
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(1).unwrap()).path,
        result.path,
        "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap();
    assert_eq!(
        VALUES4[2], result.value,
        "Value for merkle path at index 2 must match leaf value"
    );
    assert_eq!(
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(2).unwrap()).path,
        result.path,
        "merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap();
    assert_eq!(
        VALUES4[3], result.value,
        "Value for merkle path at index 3 must match leaf value"
    );
    assert_eq!(
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(3).unwrap()).path,
        result.path,
        "merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap();
    assert_eq!(
        RpoDigest::default(),
        result.value,
        "Value for merkle path at index 4 must match leaf value"
    );
    assert_eq!(
        smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(4).unwrap()).path,
        result.path,
        "merkle path for index 4 must be the same for the MerkleTree and MerkleStore"
    );

    Ok(())
}
#[test]
fn test_add_merkle_paths() -> Result<(), MerkleError> {
    let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;

    let i0 = 0;
    let p0 = mtree.get_path(NodeIndex::make(2, i0)).unwrap();

    let i1 = 1;
    let p1 = mtree.get_path(NodeIndex::make(2, i1)).unwrap();

    let i2 = 2;
    let p2 = mtree.get_path(NodeIndex::make(2, i2)).unwrap();

    let i3 = 3;
    let p3 = mtree.get_path(NodeIndex::make(2, i3)).unwrap();

    let paths = [
        (i0, VALUES4[i0 as usize], p0),
        (i1, VALUES4[i1 as usize], p1),
        (i2, VALUES4[i2 as usize], p2),
        (i3, VALUES4[i3 as usize], p3),
    ];

    let mut store = MerkleStore::default();
    store.add_merkle_paths(paths.clone()).expect("the valid paths must work");

    let pmt = PartialMerkleTree::with_paths(paths).unwrap();

    // STORE LEAVES ARE CORRECT ==============================================================
    // checks that the leaves in the store correspond to the expected values
    assert_eq!(
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
        VALUES4[0],
        "node 0 must be in the pmt"
    );
    assert_eq!(
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
        VALUES4[1],
        "node 1 must be in the pmt"
    );
    assert_eq!(
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
        VALUES4[2],
        "node 2 must be in the pmt"
    );
    assert_eq!(
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
        VALUES4[3],
        "node 3 must be in the pmt"
    );

    // STORE LEAVES MATCH PMT ================================================================
    // sanity check the values returned by the store and the pmt
    assert_eq!(
        pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
        "node 0 must be the same for both PartialMerkleTree and MerkleStore"
    );
    assert_eq!(
        pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
        "node 1 must be the same for both PartialMerkleTree and MerkleStore"
    );
    assert_eq!(
        pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
        "node 2 must be the same for both PartialMerkleTree and MerkleStore"
    );
    assert_eq!(
        pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
        store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
        "node 3 must be the same for both PartialMerkleTree and MerkleStore"
    );

    // STORE MERKLE PATH MATCHES ==============================================================
    // assert the merkle path returned by the store is the same as the one in the pmt
    let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap();
    assert_eq!(
        VALUES4[0], result.value,
        "Value for merkle path at index 0 must match leaf value"
    );
    assert_eq!(
        pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
        result.path,
        "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap();
    assert_eq!(
        VALUES4[1], result.value,
        "Value for merkle path at index 1 must match leaf value"
    );
    assert_eq!(
        pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
        result.path,
        "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap();
    assert_eq!(
        VALUES4[2], result.value,
        "Value for merkle path at index 2 must match leaf value"
    );
    assert_eq!(
        pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
        result.path,
        "merkle path for index 2 must be the same for the MerkleTree and MerkleStore"
    );

    let result = store.get_path(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap();
    assert_eq!(
        VALUES4[3], result.value,
        "Value for merkle path at index 3 must match leaf value"
    );
    assert_eq!(
        pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
        result.path,
        "merkle path for index 3 must be the same for the MerkleTree and MerkleStore"
    );

    Ok(())
}
#[test]
fn wont_open_to_different_depth_root() {
    let empty = EmptySubtreeRoots::empty_hashes(64);
    let a = [ONE; 4];
    let b = [Felt::new(2); 4];

    // Compute the root for a different depth. We cherry-pick this specific depth to prevent a
    // regression of a past bug that allowed the user to fetch a node at a depth lower than the
    // inserted path of a Merkle tree.
    let mut root = Rpo256::merge(&[a.into(), b.into()]);
    for depth in (1..=63).rev() {
        root = Rpo256::merge(&[root, empty[depth]]);
    }

    // For this example, the depth of the Merkle tree is 1, as we have only two leaves. Here we
    // attempt to fetch a node on the maximum depth, and it should fail because the root shouldn't
    // exist for the set.
    let mtree = MerkleTree::new(vec![a, b]).unwrap();
    let store = MerkleStore::from(&mtree);
    let index = NodeIndex::root();
    let err = store.get_node(root, index).err().unwrap();
    assert_matches!(err, MerkleError::RootNotInStore(err_root) if err_root == root);
}

#[test]
fn store_path_opens_from_leaf() {
    let a = [ONE; 4];
    let b = [Felt::new(2); 4];
    let c = [Felt::new(3); 4];
    let d = [Felt::new(4); 4];
    let e = [Felt::new(5); 4];
    let f = [Felt::new(6); 4];
    let g = [Felt::new(7); 4];
    let h = [Felt::new(8); 4];

    let i = Rpo256::merge(&[a.into(), b.into()]);
    let j = Rpo256::merge(&[c.into(), d.into()]);
    let k = Rpo256::merge(&[e.into(), f.into()]);
    let l = Rpo256::merge(&[g.into(), h.into()]);

    let m = Rpo256::merge(&[i, j]);
    let n = Rpo256::merge(&[k, l]);

    let root = Rpo256::merge(&[m, n]);

    let mtree = MerkleTree::new(vec![a, b, c, d, e, f, g, h]).unwrap();
    let store = MerkleStore::from(&mtree);
    let path = store.get_path(root, NodeIndex::make(3, 1)).unwrap().path;

    let expected = MerklePath::new([a.into(), j, n].to_vec());
    assert_eq!(path, expected);
}

#[test]
fn test_set_node() -> Result<(), MerkleError> {
    let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
    let mut store = MerkleStore::from(&mtree);
    let value = int_to_node(42);
    let index = NodeIndex::make(mtree.depth(), 0);
    let new_root = store.set_node(mtree.root(), index, value)?.root;
    assert_eq!(store.get_node(new_root, index).unwrap(), value, "value must have changed");

    Ok(())
}
#[test]
fn test_constructors() -> Result<(), MerkleError> {
    let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
    let store = MerkleStore::from(&mtree);

    let depth = mtree.depth();
    let leaves = 2u64.pow(depth.into());
    for index in 0..leaves {
        let index = NodeIndex::make(depth, index);
        let value_path = store.get_path(mtree.root(), index)?;
        assert_eq!(mtree.get_path(index)?, value_path.path);
    }

    const DEPTH: u8 = 32;
    let smt =
        SimpleSmt::<DEPTH>::with_leaves(KEYS4.into_iter().zip(digests_to_words(&VALUES4))).unwrap();
    let store = MerkleStore::from(&smt);

    for key in KEYS4 {
        let index = NodeIndex::make(DEPTH, key);
        let value_path = store.get_path(smt.root(), index)?;
        assert_eq!(smt.open(&LeafIndex::<DEPTH>::new(key).unwrap()).path, value_path.path);
    }

    let d = 2;
    let paths = [
        (0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0)).unwrap()),
        (1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1)).unwrap()),
        (2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2)).unwrap()),
        (3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3)).unwrap()),
    ];

    let mut store1 = MerkleStore::default();
    store1.add_merkle_paths(paths.clone())?;

    let mut store2 = MerkleStore::default();
    store2.add_merkle_path(0, VALUES4[0], mtree.get_path(NodeIndex::make(d, 0))?)?;
    store2.add_merkle_path(1, VALUES4[1], mtree.get_path(NodeIndex::make(d, 1))?)?;
    store2.add_merkle_path(2, VALUES4[2], mtree.get_path(NodeIndex::make(d, 2))?)?;
    store2.add_merkle_path(3, VALUES4[3], mtree.get_path(NodeIndex::make(d, 3))?)?;
    let pmt = PartialMerkleTree::with_paths(paths).unwrap();

    for key in [0, 1, 2, 3] {
        let index = NodeIndex::make(d, key);
        let value_path1 = store1.get_path(pmt.root(), index)?;
        let value_path2 = store2.get_path(pmt.root(), index)?;
        assert_eq!(value_path1, value_path2);

        let index = NodeIndex::make(d, key);
        assert_eq!(pmt.get_path(index)?, value_path1.path);
    }

    Ok(())
}

#[test]
fn node_path_should_be_truncated_by_midtier_insert() {
    let key = 0b11010010_11001100_11001100_11001100_11001100_11001100_11001100_11001100_u64;

    let mut store = MerkleStore::new();
    let root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];

    // insert first node - works as expected
    let depth = 64;
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    let index = NodeIndex::new(depth, key).unwrap();
    let root = store.set_node(root, index, node).unwrap().root;
    let result = store.get_node(root, index).unwrap();
    let path = store.get_path(root, index).unwrap().path;
    assert_eq!(node, result);
    assert_eq!(path.depth(), depth);
    assert!(path.verify(index.value(), result, &root).is_ok());

    // flip the first bit of the key and insert the second node on a different depth
    let key = key ^ (1 << 63);
    let key = key >> 8;
    let depth = 56;
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    let index = NodeIndex::new(depth, key).unwrap();
    let root = store.set_node(root, index, node).unwrap().root;
    let result = store.get_node(root, index).unwrap();
    let path = store.get_path(root, index).unwrap().path;
    assert_eq!(node, result);
    assert_eq!(path.depth(), depth);
    assert!(path.verify(index.value(), result, &root).is_ok());

    // attempt to fetch a path of the second node to depth 64
    // should fail because the previously inserted node will remove its sub-tree from the set
    let key = key << 8;
    let index = NodeIndex::new(64, key).unwrap();
    assert!(store.get_node(root, index).is_err());
}
// LEAF TRAVERSAL
// ================================================================================================

#[test]
fn get_leaf_depth_works_depth_64() {
    let mut store = MerkleStore::new();
    let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];
    let key = u64::MAX;

    // this will create a rainbow tree and test all openings to depth 64
    for d in 0..64 {
        let k = key & (u64::MAX >> d);
        let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
        let index = NodeIndex::new(64, k).unwrap();

        // assert the leaf doesn't exist before the insert. the returned depth should always
        // increment with the paths count of the set, as they are intersecting one another up to
        // the first bits of the used key.
        assert_eq!(d, store.get_leaf_depth(root, 64, k).unwrap());

        // insert and assert the correct depth
        root = store.set_node(root, index, node).unwrap().root;
        assert_eq!(64, store.get_leaf_depth(root, 64, k).unwrap());
    }
}

#[test]
fn get_leaf_depth_works_with_incremental_depth() {
    let mut store = MerkleStore::new();
    let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(64)[0];

    // insert some path to the left of the root and assert it
    let key = 0b01001011_10110110_00001101_01110100_00111011_10101101_00000100_01000001_u64;
    assert_eq!(0, store.get_leaf_depth(root, 64, key).unwrap());
    let depth = 64;
    let index = NodeIndex::new(depth, key).unwrap();
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    root = store.set_node(root, index, node).unwrap().root;
    assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

    // flip the key to the right of the root and insert some content on depth 16
    let key = 0b11001011_10110110_00000000_00000000_00000000_00000000_00000000_00000000_u64;
    assert_eq!(1, store.get_leaf_depth(root, 64, key).unwrap());
    let depth = 16;
    let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    root = store.set_node(root, index, node).unwrap().root;
    assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

    // attempt the sibling of the previous leaf
    let key = 0b11001011_10110111_00000000_00000000_00000000_00000000_00000000_00000000_u64;
    assert_eq!(16, store.get_leaf_depth(root, 64, key).unwrap());
    let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    root = store.set_node(root, index, node).unwrap().root;
    assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());

    // move down to the next depth and assert correct behavior
    let key = 0b11001011_10110100_00000000_00000000_00000000_00000000_00000000_00000000_u64;
    assert_eq!(15, store.get_leaf_depth(root, 64, key).unwrap());
    let depth = 17;
    let index = NodeIndex::new(depth, key >> (64 - depth)).unwrap();
    let node = RpoDigest::from([Felt::new(key); WORD_SIZE]);
    root = store.set_node(root, index, node).unwrap().root;
    assert_eq!(depth, store.get_leaf_depth(root, 64, key).unwrap());
}

#[test]
fn get_leaf_depth_works_with_depth_8() {
    let mut store = MerkleStore::new();
    let mut root: RpoDigest = EmptySubtreeRoots::empty_hashes(8)[0];

    // insert some random depth-8 keys; `a` diverges from the others at the first bit
    let a = 0b01101001_u64;
    let b = 0b10011001_u64;
    let c = 0b10010110_u64;
    let d = 0b11110110_u64;

    for k in [a, b, c, d] {
        let index = NodeIndex::new(8, k).unwrap();
        let node = RpoDigest::from([Felt::new(k); WORD_SIZE]);
        root = store.set_node(root, index, node).unwrap().root;
    }

    // assert that all leaves return the inserted depth
    for k in [a, b, c, d] {
        assert_eq!(8, store.get_leaf_depth(root, 8, k).unwrap());
    }

    // flip last bit of a and expect it to return the same depth, but for an empty node
    assert_eq!(8, store.get_leaf_depth(root, 8, 0b01101000_u64).unwrap());

    // flip fourth bit of a and expect an empty node on depth 4
    assert_eq!(4, store.get_leaf_depth(root, 8, 0b01111001_u64).unwrap());

    // flip third bit of a and expect an empty node on depth 3
    assert_eq!(3, store.get_leaf_depth(root, 8, 0b01001001_u64).unwrap());

    // flip second bit of a and expect an empty node on depth 2
    assert_eq!(2, store.get_leaf_depth(root, 8, 0b00101001_u64).unwrap());

    // flip fourth bit of c and expect an empty node on depth 4
    assert_eq!(4, store.get_leaf_depth(root, 8, 0b10000110_u64).unwrap());

    // flip second bit of d and expect an empty node on depth 3 as depth 2 conflicts with b and c
    assert_eq!(3, store.get_leaf_depth(root, 8, 0b10110110_u64).unwrap());

    // duplicate the tree on `a` and assert the depth is short-circuited by such sub-tree
    let index = NodeIndex::new(8, a).unwrap();
    root = store.set_node(root, index, root).unwrap().root;
    assert_matches!(store.get_leaf_depth(root, 8, a).unwrap_err(), MerkleError::DepthTooBig(9));
}

#[test]
fn find_lone_leaf() {
    let mut store = MerkleStore::new();
    let empty = EmptySubtreeRoots::empty_hashes(64);
    let mut root: RpoDigest = empty[0];

    // insert a single leaf into the store at depth 64
    let key_a = 0b01010101_10101010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
    let idx_a = NodeIndex::make(64, key_a);
    let val_a = RpoDigest::from([ONE, ONE, ONE, ONE]);
    root = store.set_node(root, idx_a, val_a).unwrap().root;

    // for every ancestor of A, A should be a lone leaf
    for depth in 1..64 {
        let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
        let parent = store.get_node(root, parent_index).unwrap();

        let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
        assert_eq!(res, Some((idx_a, val_a)));
    }

    // insert another leaf into the store such that it has the same 8 bit prefix as A
    let key_b = 0b01010101_01111010_00001111_01110100_00111011_10101101_00000100_01000001_u64;
    let idx_b = NodeIndex::make(64, key_b);
    let val_b = RpoDigest::from([ONE, ONE, ONE, ZERO]);
    root = store.set_node(root, idx_b, val_b).unwrap().root;

    // for any node which is common between A and B, find_lone_leaf() should return None as the
    // node has two descendants
    for depth in 1..9 {
        let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
        let parent = store.get_node(root, parent_index).unwrap();

        let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
        assert_eq!(res, None);
    }

    // for other ancestors of A and B, A and B should be lone leaves respectively
    for depth in 9..64 {
        let parent_index = NodeIndex::make(depth, key_a >> (64 - depth));
        let parent = store.get_node(root, parent_index).unwrap();

        let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
        assert_eq!(res, Some((idx_a, val_a)));
    }

    for depth in 9..64 {
        let parent_index = NodeIndex::make(depth, key_b >> (64 - depth));
        let parent = store.get_node(root, parent_index).unwrap();

        let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
        assert_eq!(res, Some((idx_b, val_b)));
    }

    // for any other node, find_lone_leaf() should return None as they have no leaf nodes
    let parent_index = NodeIndex::make(16, 0b01010101_11111111);
    let parent = store.get_node(root, parent_index).unwrap();
    let res = store.find_lone_leaf(parent, parent_index, 64).unwrap();
|
||||
assert_eq!(res, None);
|
||||
}
|
||||
|
||||
// SUBSET EXTRACTION
|
||||
// ================================================================================================
|
||||
|
||||
#[test]
|
||||
fn mstore_subset() {
|
||||
// add a Merkle tree of depth 3 to the store
|
||||
let mtree = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();
|
||||
let mut store = MerkleStore::default();
|
||||
let empty_store_num_nodes = store.nodes.len();
|
||||
store.extend(mtree.inner_nodes());
|
||||
|
||||
// build 3 subtrees contained within the above Merkle tree; note that subtree2 is a subset
|
||||
// of subtree1
|
||||
let subtree1 = MerkleTree::new(digests_to_words(&VALUES8[..4])).unwrap();
|
||||
let subtree2 = MerkleTree::new(digests_to_words(&VALUES8[2..4])).unwrap();
|
||||
let subtree3 = MerkleTree::new(digests_to_words(&VALUES8[6..])).unwrap();
|
||||
|
||||
// --- extract all 3 subtrees ---------------------------------------------
|
||||
|
||||
let substore = store.subset([subtree1.root(), subtree2.root(), subtree3.root()].iter());
|
||||
|
||||
// number of nodes should increase by 4: 3 nodes form subtree1 and 1 node from subtree3
|
||||
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);
|
||||
|
||||
// make sure paths that all subtrees are in the store
|
||||
check_mstore_subtree(&substore, &subtree1);
|
||||
check_mstore_subtree(&substore, &subtree2);
|
||||
check_mstore_subtree(&substore, &subtree3);
|
||||
|
||||
// --- extract subtrees 1 and 3 -------------------------------------------
|
||||
// this should give the same result as above as subtree2 is nested within subtree1
|
||||
|
||||
let substore = store.subset([subtree1.root(), subtree3.root()].iter());
|
||||
|
||||
// number of nodes should increase by 4: 3 nodes form subtree1 and 1 node from subtree3
|
||||
assert_eq!(substore.nodes.len(), empty_store_num_nodes + 4);
|
||||
|
||||
// make sure paths that all subtrees are in the store
|
||||
check_mstore_subtree(&substore, &subtree1);
|
||||
check_mstore_subtree(&substore, &subtree2);
|
||||
check_mstore_subtree(&substore, &subtree3);
|
||||
}
|
||||
|
||||
fn check_mstore_subtree(store: &MerkleStore, subtree: &MerkleTree) {
|
||||
for (i, value) in subtree.leaves() {
|
||||
let index = NodeIndex::new(subtree.depth(), i).unwrap();
|
||||
let path1 = store.get_path(subtree.root(), index).unwrap();
|
||||
assert_eq!(*path1.value, *value);
|
||||
|
||||
let path2 = subtree.get_path(index).unwrap();
|
||||
assert_eq!(path1.path, path2);
|
||||
}
|
||||
}
|
||||
|
||||
// SERIALIZATION
|
||||
// ================================================================================================
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn test_serialization() -> Result<(), Box<dyn Error>> {
|
||||
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
|
||||
let store = MerkleStore::from(&mtree);
|
||||
let decoded = MerkleStore::read_from_bytes(&store.to_bytes()).expect("deserialization failed");
|
||||
assert_eq!(store, decoded);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// MERKLE RECORDER
|
||||
// ================================================================================================
|
||||
#[test]
|
||||
fn test_recorder() {
|
||||
// instantiate recorder from MerkleTree and SimpleSmt
|
||||
let mtree = MerkleTree::new(digests_to_words(&VALUES4)).unwrap();
|
||||
|
||||
const TREE_DEPTH: u8 = 64;
|
||||
let smtree = SimpleSmt::<TREE_DEPTH>::with_leaves(
|
||||
KEYS8.into_iter().zip(VALUES8.into_iter().map(|x| x.into()).rev()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut recorder: RecordingMerkleStore =
|
||||
mtree.inner_nodes().chain(smtree.inner_nodes()).collect();
|
||||
|
||||
// get nodes from both trees and make sure they are correct
|
||||
let index_0 = NodeIndex::new(mtree.depth(), 0).unwrap();
|
||||
let node = recorder.get_node(mtree.root(), index_0).unwrap();
|
||||
assert_eq!(node, mtree.get_node(index_0).unwrap());
|
||||
|
||||
let index_1 = NodeIndex::new(TREE_DEPTH, 1).unwrap();
|
||||
let node = recorder.get_node(smtree.root(), index_1).unwrap();
|
||||
assert_eq!(node, smtree.get_node(index_1).unwrap());
|
||||
|
||||
// insert a value and assert that when we request it next time it is accurate
|
||||
let new_value = [ZERO, ZERO, ONE, ONE].into();
|
||||
let index_2 = NodeIndex::new(TREE_DEPTH, 2).unwrap();
|
||||
let root = recorder.set_node(smtree.root(), index_2, new_value).unwrap().root;
|
||||
assert_eq!(recorder.get_node(root, index_2).unwrap(), new_value);
|
||||
|
||||
// construct the proof
|
||||
let rec_map = recorder.into_inner();
|
||||
let (_, proof) = rec_map.finalize();
|
||||
let merkle_store: MerkleStore = proof.into();
|
||||
|
||||
// make sure the proof contains all nodes from both trees
|
||||
let node = merkle_store.get_node(mtree.root(), index_0).unwrap();
|
||||
assert_eq!(node, mtree.get_node(index_0).unwrap());
|
||||
|
||||
let node = merkle_store.get_node(smtree.root(), index_1).unwrap();
|
||||
assert_eq!(node, smtree.get_node(index_1).unwrap());
|
||||
|
||||
let node = merkle_store.get_node(smtree.root(), index_2).unwrap();
|
||||
assert_eq!(
|
||||
node,
|
||||
smtree.get_leaf(&LeafIndex::<TREE_DEPTH>::try_from(index_2).unwrap()).into()
|
||||
);
|
||||
|
||||
// assert that is doesnt contain nodes that were not recorded
|
||||
let not_recorded_index = NodeIndex::new(TREE_DEPTH, 4).unwrap();
|
||||
assert!(merkle_store.get_node(smtree.root(), not_recorded_index).is_err());
|
||||
assert!(smtree.get_node(not_recorded_index).is_ok());
|
||||
}
|
23
src/rand/mod.rs
Normal file
@@ -0,0 +1,23 @@
//! Pseudo-random element generation.

use rand::RngCore;
pub use winter_crypto::{DefaultRandomCoin as WinterRandomCoin, RandomCoin, RandomCoinError};
pub use winter_utils::Randomizable;

use crate::{Felt, FieldElement, Word, ZERO};

mod rpo;
mod rpx;
pub use rpo::RpoRandomCoin;
pub use rpx::RpxRandomCoin;

/// Pseudo-random element generator.
///
/// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s.
pub trait FeltRng: RngCore {
    /// Draw, uniformly at random, a base field element.
    fn draw_element(&mut self) -> Felt;

    /// Draw, uniformly at random, a [Word].
    fn draw_word(&mut self) -> Word;
}
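
A minimal usage sketch of this trait (not part of the diff; the `RpoRandomCoin` constructor is the one defined in src/rand/rpo.rs below, and the all-zero seed is a placeholder):

// Sketch: any FeltRng can act as a source of field elements and words.
fn sample<R: FeltRng>(rng: &mut R) -> (Felt, Word) {
    let element = rng.draw_element(); // one uniformly random base field element
    let word = rng.draw_word(); // four field elements drawn back-to-back
    (element, word)
}

// for example, with the RPO-based coin:
// let mut rng = RpoRandomCoin::new([ZERO; 4]);
// let (element, word) = sample(&mut rng);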
296
src/rand/rpo.rs
Normal file
@@ -0,0 +1,296 @@
use alloc::{string::ToString, vec::Vec};

use rand_core::impls;

use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};

// CONSTANTS
// ================================================================================================

const STATE_WIDTH: usize = Rpo256::STATE_WIDTH;
const RATE_START: usize = Rpo256::RATE_RANGE.start;
const RATE_END: usize = Rpo256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.start) / 2;

// RPO RANDOM COIN
// ================================================================================================

/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
///    is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
///    material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpoRandomCoin {
    state: [Felt; STATE_WIDTH],
    current: usize,
}

impl RpoRandomCoin {
    /// Returns a new [RpoRandomCoin] initialized with the specified seed.
    pub fn new(seed: Word) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        for i in 0..HALF_RATE_WIDTH {
            state[RATE_START + i] += seed[i];
        }

        // Absorb
        Rpo256::apply_permutation(&mut state);

        RpoRandomCoin { state, current: RATE_START }
    }

    /// Returns an [RpoRandomCoin] instantiated from the provided components.
    ///
    /// # Panics
    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
        assert!(
            (RATE_START..RATE_END).contains(&current),
            "current value outside of valid range"
        );
        Self { state, current }
    }

    /// Returns components of this random coin.
    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
        (self.state, self.current)
    }

    /// Fills `dest` with random data.
    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
        <Self as RngCore>::fill_bytes(self, dest)
    }

    fn draw_basefield(&mut self) -> Felt {
        // once the rate portion is exhausted, permute the state and start over
        if self.current == RATE_END {
            Rpo256::apply_permutation(&mut self.state);
            self.current = RATE_START;
        }

        self.current += 1;
        self.state[self.current - 1]
    }
}
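
For illustration (a sketch, not part of the file), the constructor pair above round-trips the sponge state; the all-zero seed is arbitrary:

let coin = RpoRandomCoin::new([ZERO; 4]);
let (state, current) = coin.into_parts(); // RpoRandomCoin is Copy, so `coin` remains usable
let restored = RpoRandomCoin::from_parts(state, current);
assert_eq!(coin, restored);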

// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RandomCoin for RpoRandomCoin {
    type BaseField = Felt;
    type Hasher = Rpo256;

    fn new(seed: &[Self::BaseField]) -> Self {
        let digest: Word = Rpo256::hash_elements(seed).into();
        Self::new(digest)
    }

    fn reseed(&mut self, data: RpoDigest) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPO state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        // Absorb
        Rpo256::apply_permutation(&mut self.state);
    }

    fn check_leading_zeros(&self, value: u64) -> u32 {
        let value = Felt::new(value);
        let mut state_tmp = self.state;

        state_tmp[RATE_START] += value;

        Rpo256::apply_permutation(&mut state_tmp);

        let first_rate_element = state_tmp[RATE_START].as_int();
        first_rate_element.trailing_zeros()
    }

    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
        let ext_degree = E::EXTENSION_DEGREE;
        let mut result = vec![ZERO; ext_degree];
        for r in result.iter_mut().take(ext_degree) {
            *r = self.draw_basefield();
        }

        let result = E::slice_from_base_elements(&result);
        Ok(result[0])
    }

    fn draw_integers(
        &mut self,
        num_values: usize,
        domain_size: usize,
        nonce: u64,
    ) -> Result<Vec<usize>, RandomCoinError> {
        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
        assert!(num_values < domain_size, "number of values must be smaller than domain size");

        // absorb the nonce
        let nonce = Felt::new(nonce);
        self.state[RATE_START] += nonce;
        Rpo256::apply_permutation(&mut self.state);

        // reset the buffer and move the next random element pointer to the second rate element.
        // this is done as the first rate element will be "biased" via the provided `nonce` to
        // contain some number of leading zeros.
        self.current = RATE_START + 1;

        // since the domain size is a power of two, a bitmask over its low bits selects values
        // inside the domain
        let v_mask = (domain_size - 1) as u64;

        // draw values from the PRNG until we get as many values as specified by `num_values`
        let mut values = Vec::new();
        for _ in 0..1000 {
            // get the next pseudo-random field element
            let value = self.draw_basefield().as_int();

            // use the mask to get a value within the range
            let value = (value & v_mask) as usize;

            values.push(value);
            if values.len() == num_values {
                break;
            }
        }

        if values.len() < num_values {
            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
        }

        Ok(values)
    }
}
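
A hedged sketch of how the methods above compose in a Fiat-Shamir style transcript (all values are illustrative, not from the diff):

let mut coin = RpoRandomCoin::new([ZERO; 4]);
coin.reseed(RpoDigest::from([ONE, ONE, ONE, ONE])); // absorb a commitment: one permutation
let _challenge: Felt = coin.draw().expect("drawing a base field element cannot fail here");
let queries = coin.draw_integers(8, 1024, 0).expect("8 values from a domain of size 1024");
assert!(queries.iter().all(|&q| q < 1024)); // each draw is masked into the domain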

// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl FeltRng for RpoRandomCoin {
    fn draw_element(&mut self) -> Felt {
        self.draw_basefield()
    }

    fn draw_word(&mut self) -> Word {
        let mut output = [ZERO; 4];
        for o in output.iter_mut() {
            *o = self.draw_basefield();
        }
        output
    }
}

// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RngCore for RpoRandomCoin {
    fn next_u32(&mut self) -> u32 {
        self.draw_basefield().as_int() as u32
    }

    fn next_u64(&mut self) -> u64 {
        impls::next_u64_via_u32(self)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        impls::fill_bytes_via_next(self, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

// SERIALIZATION
// ------------------------------------------------------------------------------------------------

impl Serializable for RpoRandomCoin {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.state.iter().for_each(|v| v.write_into(target));
        // casting to u8 is OK because `current` is always between 4 and 12.
        target.write_u8(self.current as u8);
    }
}

impl Deserializable for RpoRandomCoin {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let state = [
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
        ];
        let current = source.read_u8()? as usize;
        if !(RATE_START..RATE_END).contains(&current) {
            return Err(DeserializationError::InvalidValue(
                "current value outside of valid range".to_string(),
            ));
        }
        Ok(Self { state, current })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpoRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_element();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let expected = rpocoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let output = rpocoin.draw_word();

        let mut rpocoin = RpoRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpocoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpoRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpoRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}
294
src/rand/rpx.rs
Normal file
@@ -0,0 +1,294 @@
use alloc::{string::ToString, vec::Vec};

use rand_core::impls;

use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
    hash::rpx::{Rpx256, RpxDigest},
    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};

// CONSTANTS
// ================================================================================================

const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
const RATE_START: usize = Rpx256::RATE_RANGE.start;
const RATE_END: usize = Rpx256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;

// RPX RANDOM COIN
// ================================================================================================

/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
///    is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
///    material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpxRandomCoin {
    state: [Felt; STATE_WIDTH],
    current: usize,
}

impl RpxRandomCoin {
    /// Returns a new [RpxRandomCoin] initialized with the specified seed.
    pub fn new(seed: Word) -> Self {
        let mut state = [ZERO; STATE_WIDTH];

        for i in 0..HALF_RATE_WIDTH {
            state[RATE_START + i] += seed[i];
        }

        // Absorb
        Rpx256::apply_permutation(&mut state);

        RpxRandomCoin { state, current: RATE_START }
    }

    /// Returns an [RpxRandomCoin] instantiated from the provided components.
    ///
    /// # Panics
    /// Panics if `current` is smaller than 4 or greater than or equal to 12.
    pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
        assert!(
            (RATE_START..RATE_END).contains(&current),
            "current value outside of valid range"
        );
        Self { state, current }
    }

    /// Returns components of this random coin.
    pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
        (self.state, self.current)
    }

    /// Fills `dest` with random data.
    pub fn fill_bytes(&mut self, dest: &mut [u8]) {
        <Self as RngCore>::fill_bytes(self, dest)
    }

    fn draw_basefield(&mut self) -> Felt {
        // once the rate portion is exhausted, permute the state and start over
        if self.current == RATE_END {
            Rpx256::apply_permutation(&mut self.state);
            self.current = RATE_START;
        }

        self.current += 1;
        self.state[self.current - 1]
    }
}

// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RandomCoin for RpxRandomCoin {
    type BaseField = Felt;
    type Hasher = Rpx256;

    fn new(seed: &[Self::BaseField]) -> Self {
        let digest: Word = Rpx256::hash_elements(seed).into();
        Self::new(digest)
    }

    fn reseed(&mut self, data: RpxDigest) {
        // Reset buffer
        self.current = RATE_START;

        // Add the new seed material to the first half of the rate portion of the RPX state
        let data: Word = data.into();

        self.state[RATE_START] += data[0];
        self.state[RATE_START + 1] += data[1];
        self.state[RATE_START + 2] += data[2];
        self.state[RATE_START + 3] += data[3];

        // Absorb
        Rpx256::apply_permutation(&mut self.state);
    }

    fn check_leading_zeros(&self, value: u64) -> u32 {
        let value = Felt::new(value);
        let mut state_tmp = self.state;

        state_tmp[RATE_START] += value;

        Rpx256::apply_permutation(&mut state_tmp);

        let first_rate_element = state_tmp[RATE_START].as_int();
        first_rate_element.trailing_zeros()
    }

    fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
        let ext_degree = E::EXTENSION_DEGREE;
        let mut result = vec![ZERO; ext_degree];
        for r in result.iter_mut().take(ext_degree) {
            *r = self.draw_basefield();
        }

        let result = E::slice_from_base_elements(&result);
        Ok(result[0])
    }

    fn draw_integers(
        &mut self,
        num_values: usize,
        domain_size: usize,
        nonce: u64,
    ) -> Result<Vec<usize>, RandomCoinError> {
        assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
        assert!(num_values < domain_size, "number of values must be smaller than domain size");

        // absorb the nonce
        let nonce = Felt::new(nonce);
        self.state[RATE_START] += nonce;
        Rpx256::apply_permutation(&mut self.state);

        // reset the buffer
        self.current = RATE_START;

        // since the domain size is a power of two, a bitmask over its low bits selects values
        // inside the domain
        let v_mask = (domain_size - 1) as u64;

        // draw values from the PRNG until we get as many values as specified by `num_values`
        let mut values = Vec::new();
        for _ in 0..1000 {
            // get the next pseudo-random field element
            let value = self.draw_basefield().as_int();

            // use the mask to get a value within the range
            let value = (value & v_mask) as usize;

            values.push(value);
            if values.len() == num_values {
                break;
            }
        }

        if values.len() < num_values {
            return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
        }

        Ok(values)
    }
}

// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl FeltRng for RpxRandomCoin {
    fn draw_element(&mut self) -> Felt {
        self.draw_basefield()
    }

    fn draw_word(&mut self) -> Word {
        let mut output = [ZERO; 4];
        for o in output.iter_mut() {
            *o = self.draw_basefield();
        }
        output
    }
}

// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------

impl RngCore for RpxRandomCoin {
    fn next_u32(&mut self) -> u32 {
        self.draw_basefield().as_int() as u32
    }

    fn next_u64(&mut self) -> u64 {
        impls::next_u64_via_u32(self)
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        impls::fill_bytes_via_next(self, dest)
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}
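
Because of the impl above, the coin can be handed to any consumer of the rand API; a small sketch (illustrative values only):

let mut rng = RpxRandomCoin::new([ZERO; 4]);
let mut buf = [0u8; 8];
rng.fill_bytes(&mut buf); // bytes assembled from successive next_u32 draws
let _wide = rng.next_u64(); // likewise built from two u32 draws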

// SERIALIZATION
// ------------------------------------------------------------------------------------------------

impl Serializable for RpxRandomCoin {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.state.iter().for_each(|v| v.write_into(target));
        // casting to u8 is OK because `current` is always between 4 and 12.
        target.write_u8(self.current as u8);
    }
}

impl Deserializable for RpxRandomCoin {
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let state = [
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
            Felt::read_from(source)?,
        ];
        let current = source.read_u8()? as usize;
        if !(RATE_START..RATE_END).contains(&current) {
            return Err(DeserializationError::InvalidValue(
                "current value outside of valid range".to_string(),
            ));
        }
        Ok(Self { state, current })
    }
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_element();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let expected = rpxcoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_word();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpxcoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpxRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpxRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}
402
src/utils/kv_map.rs
Normal file
@@ -0,0 +1,402 @@
use alloc::{
    boxed::Box,
    collections::{BTreeMap, BTreeSet},
};
use core::cell::RefCell;

// KEY-VALUE MAP TRAIT
// ================================================================================================

/// A trait that defines the interface for a key-value map.
pub trait KvMap<K: Ord + Clone, V: Clone>:
    Extend<(K, V)> + FromIterator<(K, V)> + IntoIterator<Item = (K, V)>
{
    fn get(&self, key: &K) -> Option<&V>;
    fn contains_key(&self, key: &K) -> bool;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    fn insert(&mut self, key: K, value: V) -> Option<V>;
    fn remove(&mut self, key: &K) -> Option<V>;

    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_>;
}
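
Since the trait abstracts over the backing map, callers can stay generic; a brief sketch (the helper name is hypothetical, not part of the diff):

// Sketch: the same logic runs against a plain BTreeMap or a RecordingMap.
fn get_or_default<K: Ord + Clone, V: Clone + Default>(map: &impl KvMap<K, V>, key: &K) -> V {
    map.get(key).cloned().unwrap_or_default()
}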

// BTREE MAP `KvMap` IMPLEMENTATION
// ================================================================================================

impl<K: Ord + Clone, V: Clone> KvMap<K, V> for BTreeMap<K, V> {
    fn get(&self, key: &K) -> Option<&V> {
        self.get(key)
    }

    fn contains_key(&self, key: &K) -> bool {
        self.contains_key(key)
    }

    fn len(&self) -> usize {
        self.len()
    }

    fn insert(&mut self, key: K, value: V) -> Option<V> {
        self.insert(key, value)
    }

    fn remove(&mut self, key: &K) -> Option<V> {
        self.remove(key)
    }

    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
        Box::new(self.iter())
    }
}

// RECORDING MAP
// ================================================================================================

/// A [RecordingMap] records read requests made to its underlying key-value map.
///
/// The data recorder is used to generate a proof for read requests.
///
/// The [RecordingMap] is composed of three parts:
/// - `data`: the current set of key-value pairs in the map.
/// - `updates`: the keys whose values have changed since the map was instantiated. Updates
///   include insertions, removals, and changes of values under existing keys.
/// - `trace`: the key-value pairs from the original data which have been accessed since the map
///   was instantiated.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct RecordingMap<K, V> {
    data: BTreeMap<K, V>,
    updates: BTreeSet<K>,
    trace: RefCell<BTreeMap<K, V>>,
}
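
A short sketch of how `updates` and `trace` evolve under these rules (hypothetical values, not part of the file):

let mut map = RecordingMap::new([(1u64, 10u64)]);
let _ = map.get(&1); // read of an original entry: (1, 10) is copied into `trace`
map.insert(2, 20); // brand-new key: 2 joins `updates`; `trace` is untouched
map.insert(1, 11); // first write to an original key: 1 joins `updates` as well
let (state, proof) = map.finalize();
assert_eq!(state.get(&1), Some(&11)); // final state of the data
assert_eq!(proof.get(&1), Some(&10)); // original value that was accessed
assert_eq!(proof.get(&2), None); // key 2 was never part of the initial data set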

impl<K: Ord + Clone, V: Clone> RecordingMap<K, V> {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Returns a new [RecordingMap] instance initialized with the provided key-value pairs.
    pub fn new(init: impl IntoIterator<Item = (K, V)>) -> Self {
        RecordingMap {
            data: init.into_iter().collect(),
            updates: BTreeSet::new(),
            trace: RefCell::new(BTreeMap::new()),
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    pub fn inner(&self) -> &BTreeMap<K, V> {
        &self.data
    }

    // FINALIZER
    // --------------------------------------------------------------------------------------------

    /// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first
    /// element of the tuple is a map that represents the state of the map at the time `.finalize()`
    /// is called. The second element contains the key-value pairs from the initial data set that
    /// were read during recording.
    pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
        (self.data, self.trace.take())
    }

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    #[cfg(test)]
    pub fn trace_len(&self) -> usize {
        self.trace.borrow().len()
    }

    #[cfg(test)]
    pub fn updates_len(&self) -> usize {
        self.updates.len()
    }
}

impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns a reference to the value associated with the given key if the value exists.
    ///
    /// If the key is part of the initial data set, the key access is recorded.
    fn get(&self, key: &K) -> Option<&V> {
        self.data.get(key).inspect(|&value| {
            if !self.updates.contains(key) {
                self.trace.borrow_mut().insert(key.clone(), value.clone());
            }
        })
    }

    /// Returns a boolean to indicate whether the given key exists in the data set.
    ///
    /// If the key is part of the initial data set, the key access is recorded.
    fn contains_key(&self, key: &K) -> bool {
        self.get(key).is_some()
    }

    /// Returns the number of key-value pairs in the data set.
    fn len(&self) -> usize {
        self.data.len()
    }

    // MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Inserts a key-value pair into the data set.
    ///
    /// If the key already exists in the data set, the value is updated and the old value is
    /// returned.
    fn insert(&mut self, key: K, value: V) -> Option<V> {
        let new_update = self.updates.insert(key.clone());
        self.data.insert(key.clone(), value).inspect(|old_value| {
            // only record the pre-image the first time an original key is updated
            if new_update {
                self.trace.borrow_mut().insert(key, old_value.clone());
            }
        })
    }

    /// Removes a key-value pair from the data set.
    ///
    /// If the key exists in the data set, the old value is returned.
    fn remove(&mut self, key: &K) -> Option<V> {
        self.data.remove(key).inspect(|old_value| {
            let new_update = self.updates.insert(key.clone());
            if new_update {
                self.trace.borrow_mut().insert(key.clone(), old_value.clone());
            }
        })
    }

    // ITERATION
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator over the key-value pairs in the data set.
    fn iter(&self) -> Box<dyn Iterator<Item = (&K, &V)> + '_> {
        Box::new(self.data.iter())
    }
}

impl<K: Clone + Ord, V: Clone> Extend<(K, V)> for RecordingMap<K, V> {
    fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
        iter.into_iter().for_each(move |(k, v)| {
            self.insert(k, v);
        });
    }
}

impl<K: Clone + Ord, V: Clone> FromIterator<(K, V)> for RecordingMap<K, V> {
    fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
        Self::new(iter)
    }
}

impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
    type Item = (K, V);
    type IntoIter = alloc::collections::btree_map::IntoIter<K, V>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}
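
Together with Extend, these impls let a RecordingMap round-trip through plain pair iterators; a sketch (illustrative values):

let map: RecordingMap<u64, u64> = [(0, 0), (1, 1)].into_iter().collect();
let pairs: Vec<(u64, u64)> = map.into_iter().collect();
assert_eq!(pairs, [(0, 0), (1, 1)]); // BTreeMap iteration preserves key order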

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::*;

    const ITEMS: [(u64, u64); 5] = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)];

    #[test]
    fn test_get_item() {
        // instantiate a recording map
        let map = RecordingMap::new(ITEMS.to_vec());

        // get a few items
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.get(key);
        }

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, value) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert_eq!(proof.get(key), Some(value)),
                false => assert_eq!(proof.get(key), None),
            }
        }
    }

    #[test]
    fn test_contains_key() {
        // instantiate a recording map
        let map = RecordingMap::new(ITEMS.to_vec());

        // check if the map contains a few items
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, _) in ITEMS.iter() {
            match get_items.contains(key) {
                true => assert!(proof.contains_key(key)),
                false => assert!(!proof.contains_key(key)),
            }
        }
    }

    #[test]
    fn test_len() {
        // instantiate a recording map
        let mut map = RecordingMap::new(ITEMS.to_vec());
        // length of the map should be equal to the number of items
        assert_eq!(map.len(), ITEMS.len());

        // inserting an entry with a key that already exists should not change the length, but it
        // does add entries to the trace and update sets
        map.insert(4, 5);
        assert_eq!(map.len(), ITEMS.len());
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // inserting an entry with a new key should increase the length; it should also record the
        // key as an updated key, but the trace length does not change since old values were not
        // touched
        map.insert(5, 5);
        assert_eq!(map.len(), ITEMS.len() + 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        // get some items so that they are saved in the trace; this should record original items
        // in the trace, but should not affect the set of updates
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // read the same items again; this should not have any effect on the length, the trace, or
        // the set of updates
        let get_items = [0, 1, 2];
        for key in get_items.iter() {
            map.contains_key(key);
        }
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // read a newly inserted item; this should not affect the length, the trace, or the set of
        // updates
        let _val = map.get(&5).unwrap();
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // update a newly inserted item; this should not affect the length, the trace, or the set
        // of updates
        map.insert(5, 11);
        assert_eq!(map.trace_len(), 4);
        assert_eq!(map.updates_len(), 2);

        // Note: the length reported by the proof will be different from the length originally
        // reported by the map.
        let (_, proof) = map.finalize();

        // length of the proof should be equal to get_items + 1. The extra item is the original
        // value at key = 4u64
        assert_eq!(proof.len(), get_items.len() + 1);
    }

    #[test]
    fn test_iter() {
        let mut map = RecordingMap::new(ITEMS.to_vec());
        assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));

        // when inserting an entry with a key that already exists, the iterator should return the
        // new value
        let new_value = 5;
        map.insert(4, new_value);
        assert_eq!(map.iter().count(), ITEMS.len());
        assert!(map.iter().all(|(x, y)| if x == &4 {
            y == &new_value
        } else {
            ITEMS.contains(&(*x, *y))
        }));
    }

    #[test]
    fn test_is_empty() {
        // instantiate an empty recording map
        let empty_map: RecordingMap<u64, u64> = RecordingMap::default();
        assert!(empty_map.is_empty());

        // instantiate a non-empty recording map
        let map = RecordingMap::new(ITEMS.to_vec());
        assert!(!map.is_empty());
    }

    #[test]
    fn test_remove() {
        let mut map = RecordingMap::new(ITEMS.to_vec());

        // remove an item that exists
        let key = 0;
        let value = map.remove(&key).unwrap();
        assert_eq!(value, ITEMS[0].1);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // add the item back and then remove it again
        let key = 0;
        let value = 0;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 0);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // remove an item that does not exist
        let key = 100;
        let value = map.remove(&key);
        assert_eq!(value, None);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 1);

        // insert a new item and then remove it
        let key = 100;
        let value = 100;
        map.insert(key, value);
        let value = map.remove(&key).unwrap();
        assert_eq!(value, 100);
        assert_eq!(map.len(), ITEMS.len() - 1);
        assert_eq!(map.trace_len(), 1);
        assert_eq!(map.updates_len(), 2);

        // convert the map into a proof
        let (_, proof) = map.finalize();

        // check that the proof contains the expected values
        for (key, value) in ITEMS.iter() {
            match key {
                0 => assert_eq!(proof.get(key), Some(value)),
                _ => assert_eq!(proof.get(key), None),
            }
        }
    }
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue