1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge branch 'develop' of github.com:Zokrates/ZoKrates into marlin

This commit is contained in:
schaeff 2021-06-22 17:05:21 +02:00
commit 91bb00cf50
246 changed files with 8734 additions and 5428 deletions

View file

@ -1,4 +1,12 @@
version: 2
version: 2.1
executors:
linux:
machine:
image: ubuntu-2004:202101-01
macos:
macos:
xcode: 12.4.0
jobs:
build:
@ -47,9 +55,6 @@ jobs:
- run:
name: Run clippy
command: cargo clippy -- -D warnings
- run:
name: Build
command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh
- run:
name: Run tests
command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./test.sh
@ -80,7 +85,9 @@ jobs:
- v4-cargo-cache-{{ arch }}-{{ checksum "Cargo.lock" }}
- run:
name: Test on firefox
command: cd zokrates_core && wasm-pack test --firefox --headless -- --no-default-features --features "wasm bellman"
command: |
cd zokrates_core
wasm-pack test --firefox --headless -- --no-default-features --features "wasm bellman"
integration_test:
docker:
- image: zokrates/env:latest
@ -109,6 +116,7 @@ jobs:
docker_layer_caching: true
- run:
name: Release
no_output_timeout: "30m"
command: ./scripts/release.sh
zokrates_js_build:
docker:
@ -121,12 +129,101 @@ jobs:
zokrates_js_test:
docker:
- image: zokrates/env:latest
working_directory: ~/project/zokrates_js
steps:
- checkout:
path: ~/project
- run:
name: Check format
command: cargo fmt --all -- --check
- run:
name: Run clippy
command: cargo clippy -- -D warnings
- run:
name: Run tests
command: npm run test
cross_build:
parameters:
os:
type: executor
target:
type: string
executor: << parameters.os >>
steps:
- checkout
- run:
command: cd zokrates_js && npm run test
name: Calculate dependencies
command: cargo generate-lockfile
- run:
no_output_timeout: "30m"
command: cross build --target << parameters.target >> --release
- tar_artifacts:
target: << parameters.target >>
publish_artifacts:
docker:
- image: circleci/golang
steps:
- attach_workspace:
at: /tmp/artifacts
- run:
name: "Publish artifacts on GitHub"
command: |
go get github.com/github-release/github-release
github-release release \
-s ${GH_TOKEN} \
-u ${CIRCLE_PROJECT_USERNAME} \
-r ${CIRCLE_PROJECT_REPONAME} \
-t ${CIRCLE_TAG} || true
find /tmp/artifacts -type f -name *.tar.gz -exec basename {} \; | xargs -I {} github-release upload \
-s ${GH_TOKEN} \
-u ${CIRCLE_PROJECT_USERNAME} \
-r ${CIRCLE_PROJECT_REPONAME} \
-t ${CIRCLE_TAG} \
-n "{}" \
-f /tmp/artifacts/{}
commands:
install_rust:
steps:
- run:
name: Install Rust
command: |
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source $HOME/.cargo/env
install_cross:
steps:
- run:
name: Install rust-embedded/cross
command: cargo install --git https://github.com/rust-embedded/cross
tar_artifacts:
parameters:
target:
type: string
steps:
- run:
name: Store build artifacts
command: |
mkdir -p /tmp/artifacts
find target/<< parameters.target >>/release -maxdepth 1 -type f | grep -E "zokrates(\.exe)?$" | xargs -I {} cp {} /tmp/artifacts/
cp -r zokrates_stdlib/stdlib /tmp/artifacts/
cd /tmp/artifacts
tar czf zokrates-${CIRCLE_TAG}-<< parameters.target >>.tar.gz *
ls | grep -v *.tar.gz | xargs rm -rf
- store_artifacts:
path: /tmp/artifacts
- persist_to_workspace:
root: /tmp/artifacts
paths:
- zokrates-*-<< parameters.target >>.tar.gz
tag-only: &tag-only
filters:
branches:
ignore: /.*/
tags:
only: /^\d+\.\d+\.\d+$/
workflows:
version: 2
build-test-and-deploy:
jobs:
- build
@ -136,6 +233,38 @@ workflows:
- integration_test
- zokrates_js_build
- zokrates_js_test
- cross_build:
<<: *tag-only
pre-steps:
- install_rust
- install_cross
matrix:
alias: cross-build-linux
parameters:
os:
- linux
target:
- aarch64-unknown-linux-gnu
- arm-unknown-linux-gnueabi
- x86_64-unknown-linux-gnu
- x86_64-pc-windows-gnu
- cross_build:
<<: *tag-only
pre-steps:
- install_rust
- install_cross
matrix:
alias: cross-build-macos
parameters:
os:
- macos
target:
- x86_64-apple-darwin
- publish_artifacts:
<<: *tag-only
requires:
- cross-build-linux
- cross-build-macos
- deploy:
filters:
branches:
@ -148,4 +277,4 @@ workflows:
- wasm_test
- integration_test
- zokrates_js_build
- zokrates_js_test
- zokrates_js_test

11
.github/workflows/docs-check.yml vendored Normal file
View file

@ -0,0 +1,11 @@
name: Check markdown links
on: [pull_request]
jobs:
markdown-link-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: gaurav-nelson/github-action-markdown-link-check@v1
with:
use-quiet-mode: 'no'
use-verbose-mode: 'yes'

View file

@ -1,75 +0,0 @@
# Based on the "trust" template v0.1.2
# https://github.com/japaric/trust/tree/v0.1.2
dist: trusty
language: rust
rust:
- nightly
services: docker
sudo: required
env:
global:
- CRATE_NAME=zokrates
matrix:
include:
# Linux
- env: TARGET=aarch64-unknown-linux-gnu
- env: TARGET=arm-unknown-linux-gnueabi
#- env: TARGET=i686-unknown-linux-gnu
- env: TARGET=x86_64-unknown-linux-gnu
# OSX
# - env: TARGET=i686-apple-darwin
# os: osx
- env: TARGET=x86_64-apple-darwin
os: osx
# *BSD
# - env: TARGET=x86_64-unknown-freebsd
# Windows
- env: TARGET=x86_64-pc-windows-gnu
before_install:
- set -e
- rustup self update
install:
- sh ci/install.sh
- source ~/.cargo/env || true
script:
- bash ci/script.sh
after_script: set +e
before_deploy:
- sh ci/before_deploy.sh
deploy:
api_key:
secure: cpo6ukDxL+h6Dw2A4eVaC0ddU/zguuds2yhHp1UE0DUWo/lpBNtg3bw51o/GrX8JyTWJCUMLZOKJyoyUWiht41BtlqRl5Egp/ugFEfCoPS+J6u0BIBEULwXrvOmxxF+K+DLH1MX179z1R2SYBYcm8V7GvygzAwaSP4fRq3Uwqr2l3lc6Q+V2kQ0Hylmeguaqhj9lG5BQA/fG1qlWdUMTHMInCCnb2z7SP3/kWEhkdCavCWtRjaeKoWSgKDcB/UUVVnRwnq5dE76DTJU6wEqG4njityxPTTZ+u0a5FiFnUhmqtqszicAs3jAKAcekyeM0B2prTF/xPGsPqLnce4ljoSK93VU08Ut1bJNMyfRLBzd/jEwOCp6ADUQnCTDxUP4Z2iK0EGya2ciXnZi/sCwPJZPV8uqUnfHdHYOdky1+64MJE1tBgC9ZaTcLFsATD6KkffKa2rmqgZCZNeHITs6HOGZhatw6u0eLknNqqBkQIMKvGRLjI6kZxDA2HsMYNTHPevUOKu68Kebi3aQG3H3OODXO3cKvGGoPHFx4uf3E5Gn4GJEePQqC1r5zYpdrQyOEN3VyLRZVHlAR/Kzm+5mameP4CyT8ppfLfQhy+sl6OfAV6X0Ap96gbWWj0I6w0CrZ10VLgJD2W5sllyiBnsNzkccW3Yg9DCuf75/ydme/JCc=
file_glob: true
file: $CRATE_NAME-$TRAVIS_TAG-$TARGET.*
on:
tags: true
provider: releases
skip_cleanup: true
cache: cargo
before_cache:
# Travis can't cache files that are not readable by "others"
- chmod -R a+r $HOME/.cargo
branches:
only:
# release tags
- /^\d+\.\d+\.\d+.*$/
- deploy
notifications:
email:
on_success: never

View file

@ -4,6 +4,73 @@ All notable changes to this project will be documented in this file.
## [Unreleased]
https://github.com/Zokrates/ZoKrates/compare/latest...develop
## [0.7.4] - 2021-06-17
### Release
- https://github.com/Zokrates/ZoKrates/releases/tag/0.7.4 <!-- markdown-link-check-disable-line -->
### Changes
- Add `FIELD_SIZE_IN_BITS`, `FIELD_MIN` and `FIELD_MAX` constants to `field` stdlib module (#917, @dark64)
- Fix crash on import of functions containing constants (#913, @schaeff)
- Change endianness in keccak, sha3 and blake2s hash algorithms to big endian (#906, @dark64)
- Documentation improvements, move examples to a separate section, remove deprecated `--light` flag used in a rng tutorial, add a simple file system resolver example to zokrates.js docs (#914, @dark64)
- Fixed deserialization logic in the zokrates.js that caused issues on cli-compiled binaries (#912, @dark64)
- Reduce the cost of conditionals (#907, @schaeff)
- Improve propagation on if-else expressions when consequence and alternative are equal (#905, @schaeff)
- Fix access to constant in local function call (#910, @schaeff)
- Fix parsing of the left hand side of definitions (#896, @schaeff)
- Fix variable write remover when isolating branches (#904, @schaeff)
- Introduce a limit of 2**20 for for-loop sizes (#902, @schaeff)
- Run compilation test on RNG tutorial and fix bugs (#881, @axic)
## [0.7.3] - 2021-05-19
### Release
- https://github.com/Zokrates/ZoKrates/releases/tag/0.7.3
### Changes
- Remove substitution in `one_liner.sh` script which caused `Bad substitution` error with `sh`/`dash` (#877, @dark64)
- Put branch isolator behind a compilation flag in the static analyzer (#877, @dark64)
## [0.7.2] - 2021-05-18
### Release
- https://github.com/Zokrates/ZoKrates/releases/tag/0.7.2
### Changes
- Isolate branch panics: only panic in a branch if it's being logically executed (#865, @schaeff)
- Support the use of constants in struct and function declarations (#864, @dark64)
- Relax ordering of symbol declarations (#863, @dark64)
- Update `one_liner.sh` script to support arm64 architecture (#861, @dark64)
- Fix crash when updating a constant struct member to another constant (#855, @schaeff)
- Fix treatment of uint subtraction involving constants (bug) (#852, @schaeff)
- Add uint to abi docs (#848, @schaeff)
- Remove side effects on complex types (bug) (#847, @schaeff)
- Fix crash on struct member type mismatch (#846, @schaeff)
- Fix nested struct access crash (#845, @schaeff)
- Make error formatting consistent (#843, @schaeff)
## [0.7.1] - 2021-04-30
### Release
- https://github.com/Zokrates/ZoKrates/releases/tag/0.7.1
### Changes
- Fix integer inference on repeat operators (#834, @schaeff)
- Introduce constant definitions to the language (`const` keyword) (#792, @dark64)
- Introduce constant range checks for checks of the form `x < c` where `p` is a compile-time constant, also for other comparison operators. This works for any `x` and `p`, unlike dynamic `x < y` comparison (#761, @schaeff)
- Handle errors more gracefully in propagation step where applicable (#832, @dark64)
- Add interactive prompt before overwriting existing files in the `one_liner.sh` script (#831, @dark64)
- Add a custom panic hook to handle internal compiler errors more gracefully (#829, @dark64)
- Make command line errors compatible with editor cmd+click (#828, @schaeff)
- Make function definitions more permissive, and move ambiguity checks to call sites and improve them (#826, @schaeff)
- Detect assertion failures at compile time on constant expressions (#823, @dark64)
- Make function selection stricter in function calls (#822, @schaeff)
- Add the ability to import multiple symbols in a single import statement (#809, @dark64)
- Add [poseidon](https://www.poseidon-hash.info/) zk-friendly hashing algorithm to stdlib (#806, @dark64)
- Allow optional underscore before type suffix (e.g. `42_u32`) (#800, @dark64)
- Accept explicit generic parameters outside of definitions (#798, @schaeff)
## [0.7.0] - 2021-04-09
### Release

16
Cargo.lock generated
View file

@ -2284,7 +2284,7 @@ dependencies = [
[[package]]
name = "zokrates_cli"
version = "0.7.0"
version = "0.7.4"
dependencies = [
"assert_cli",
"bincode",
@ -2309,7 +2309,7 @@ version = "0.1.0"
[[package]]
name = "zokrates_core"
version = "0.6.0"
version = "0.6.4"
dependencies = [
"ark-bls12-377",
"ark-bn254",
@ -2354,7 +2354,7 @@ dependencies = [
[[package]]
name = "zokrates_core_test"
version = "0.2.0"
version = "0.2.2"
dependencies = [
"zokrates_test",
"zokrates_test_derive",
@ -2362,7 +2362,7 @@ dependencies = [
[[package]]
name = "zokrates_embed"
version = "0.1.2"
version = "0.1.3"
dependencies = [
"bellman_ce",
"sapling-crypto_ce",
@ -2400,7 +2400,7 @@ dependencies = [
[[package]]
name = "zokrates_parser"
version = "0.2.0"
version = "0.2.2"
dependencies = [
"glob 0.2.11",
"pest",
@ -2409,7 +2409,7 @@ dependencies = [
[[package]]
name = "zokrates_pest_ast"
version = "0.2.0"
version = "0.2.2"
dependencies = [
"from-pest",
"glob 0.2.11",
@ -2421,7 +2421,7 @@ dependencies = [
[[package]]
name = "zokrates_stdlib"
version = "0.2.0"
version = "0.2.3"
dependencies = [
"fs_extra",
"zokrates_test",
@ -2430,7 +2430,7 @@ dependencies = [
[[package]]
name = "zokrates_test"
version = "0.1.5"
version = "0.1.6"
dependencies = [
"serde",
"serde_derive",

View file

@ -1,5 +1,5 @@
<img src="http://www.redaktion.tu-berlin.de/fileadmin/fg308/icons/projekte/logos/ZoKrates_logo.svg" width="100%" height="180">
<img src="zokrates_logo.svg" width="100%" height="180">
# ZoKrates
@ -21,7 +21,7 @@ curl -LSfs get.zokrat.es | sh
```
Have a look at the [documentation](https://zokrates.github.io/) for more information about using ZoKrates.
A getting started tutorial can be found [here](https://zokrates.github.io/sha256example.html).
[Get started](https://zokrates.github.io/gettingstarted.html), then try a [tutorial](https://zokrates.github.io/rng_tutorial.html)!
## Getting Help
@ -43,4 +43,4 @@ You can enable zokrates git hooks locally by running:
```sh
git config core.hooksPath .githooks
```
```

View file

@ -4,7 +4,7 @@
set -e
if [ -n "$WITH_LIBSNARK" ]; then
cargo -Z package-features build --package zokrates_cli --features="libsnark"
cargo build --package zokrates_cli --features="libsnark"
else
cargo build
fi

View file

@ -4,7 +4,7 @@
set -e
if [ -n "$WITH_LIBSNARK" ]; then
cargo -Z package-features build --release --package zokrates_cli --features="libsnark"
cargo build --release --package zokrates_cli --features="libsnark"
else
cargo build --release
fi

View file

@ -1 +0,0 @@
Accept explicit generic parameters outside of definitions

View file

@ -1 +0,0 @@
Allow optional underscore before type suffix (e.g. `42_u32`)

View file

@ -1 +0,0 @@
Add [poseidon](https://www.poseidon-hash.info/) zk-friendly hashing algorithm to stdlib

View file

@ -1 +0,0 @@
Add the ability to import multiple symbols in a single import statement

View file

@ -1 +0,0 @@
Make function selection stricter in function calls

View file

@ -1 +0,0 @@
Detect assertion failures at compile time on constant expressions

View file

@ -1 +0,0 @@
Make function definitions more permissive, and move ambiguity checks to call sites and improve them

View file

@ -1 +0,0 @@
Make command line errors compatible with editor cmd+click

View file

@ -1 +0,0 @@
Add a custom panic hook to handle internal compiler errors more gracefully

View file

@ -1 +0,0 @@
Add interactive prompt before overwriting existing files in the `one_liner.sh` script

View file

@ -1 +0,0 @@
Handle errors more gracefully in propagation step where applicable

View file

@ -1,53 +0,0 @@
#!/bin/bash
# This script takes care of building your crate and packaging it for release
set -ex
main() {
local src=$(pwd) \
stage=
case $TRAVIS_OS_NAME in
linux)
stage=$(mktemp -d)
;;
osx)
stage=$(mktemp -d -t tmp)
;;
esac
case $TARGET in
x86_64-pc-windows-gnu)
BINARY_NAME=zokrates.exe
;;
*)
BINARY_NAME=zokrates
;;
esac
test -f Cargo.lock || cargo generate-lockfile
case $TRAVIS_OS_NAME in
linux)
cross build --bin zokrates --package zokrates_cli --features="libsnark" --target $TARGET --release
;;
*)
cross build --bin zokrates --package zokrates_cli --target $TARGET --release
;;
esac
# Package artifacts
# Binary
cp target/$TARGET/release/$BINARY_NAME $stage/
# Standard library
cp -r zokrates_stdlib/stdlib $stage
cd $stage
tar czf $src/$CRATE_NAME-$TRAVIS_TAG-$TARGET.tar.gz *
cd $src
rm -rf $stage
}
main

View file

@ -1,49 +0,0 @@
#!/bin/bash
set -ex
main() {
local target=
if [ $TRAVIS_OS_NAME = linux ]; then
target=x86_64-unknown-linux-musl
sort=sort
else
target=x86_64-apple-darwin
sort=gsort # for `sort --sort-version`, from brew's coreutils.
fi
# Builds for iOS are done on OSX, but require the specific target to be
# installed.
case $TARGET in
aarch64-apple-ios)
rustup target install aarch64-apple-ios
;;
armv7-apple-ios)
rustup target install armv7-apple-ios
;;
armv7s-apple-ios)
rustup target install armv7s-apple-ios
;;
i386-apple-ios)
rustup target install i386-apple-ios
;;
x86_64-apple-ios)
rustup target install x86_64-apple-ios
;;
esac
# This fetches latest stable release
local tag=$(git ls-remote --tags --refs --exit-code https://github.com/japaric/cross \
| cut -d/ -f3 \
| grep -E '^v[0.1.0-9.]+$' \
| $sort --version-sort \
| tail -n1)
curl -LSfs https://japaric.github.io/trust/install.sh | \
sh -s -- \
--force \
--git japaric/cross \
--tag $tag \
--target $target
}
main

View file

@ -1,13 +0,0 @@
#!/bin/bash
# This script takes care of testing your crate
set -ex
# This is the test phase. We will only build if tests happened before.
main() {
cross build --target $TARGET
cross build --target $TARGET --release
}
main

View file

@ -4,7 +4,7 @@
set -e
if [ -n "$WITH_LIBSNARK" ]; then
cargo -Z package-features test --release --package zokrates_cli --features="libsnark" -- --ignored --test-threads=1
cargo test --release --package zokrates_cli --features="libsnark" -- --ignored
else
cargo test --release -- --ignored --test-threads=1
cargo test --release -- --ignored
fi

1
rust-toolchain Normal file
View file

@ -0,0 +1 @@
nightly-2021-04-25

2
rust-toolchain.toml Normal file
View file

@ -0,0 +1,2 @@
[toolchain]
channel = "nightly-2021-04-25"

View file

@ -30,7 +30,7 @@ cat << EOT
## [${tag}] - $(qdate '+%Y-%m-%d')
### Release
- https://github.com/Zokrates/ZoKrates/releases/tag/${tag}
- https://github.com/Zokrates/ZoKrates/releases/tag/${tag} <!-- markdown-link-check-disable-line -->
### Changes
EOT
@ -38,9 +38,13 @@ EOT
for file in $unreleased
do
IFS=$'-' read -ra entry <<< "$file"
contents=$(cat ${CHANGELOG_PATH}/${file} | tr '\n' ' ')
author=$(join '-' ${entry[@]:1})
echo "- ${contents} (#${entry[0]}, @${author})"
IFS=$'\n' rows=$(cat ${CHANGELOG_PATH}/${file})
for row in $rows
do
echo "- ${row} (#${entry[0]}, @${author})"
done
done
echo -e "\nCopy and paste the markdown above to the appropriate CHANGELOG file."

View file

@ -150,7 +150,7 @@ get_architecture() {
fi
;;
aarch64)
aarch64 | arm64)
_cputype=aarch64
;;
@ -297,7 +297,7 @@ main() {
cp -r $td/* $dest
else
read -p "ZoKrates is already installed, overwrite (y/n)? " answer
case ${answer:0:1} in
case ${answer} in
y|Y )
rm -rf $dest/*
cp -r $td/* $dest

View file

@ -4,7 +4,7 @@
set -e
if [ -n "$WITH_LIBSNARK" ]; then
cargo -Z package-features test --release --package zokrates_cli --features="libsnark" -- --test-threads=1
cargo test --release --package zokrates_cli --features="libsnark"
else
cargo test --release -- --test-threads=1
cargo test --release
fi

View file

@ -8,11 +8,12 @@
- [Variables](language/variables.md)
- [Types](language/types.md)
- [Operators](language/operators.md)
- [Functions](language/functions.md)
- [Control flow](language/control_flow.md)
- [Constants](language/constants.md)
- [Functions](language/functions.md)
- [Generics](language/generics.md)
- [Imports](language/imports.md)
- [Comments](language/comments.md)
- [Generics](language/generics.md)
- [Macros](language/macros.md)
- [Toolbox](toolbox/index.md)
@ -24,6 +25,8 @@
- [JSON ABI](toolbox/abi.md)
- [zokrates.js](toolbox/zokrates_js.md)
- [Tutorial: A zkSNARK RNG](rng_tutorial.md)
- [Examples](examples/index.md)
- [A SNARK Powered RNG](examples/rng_tutorial.md)
- [Proving knowledge of a hash preimage](examples/sha256example.md)
- [Testing](testing.md)

View file

@ -0,0 +1,6 @@
# ZoKrates Examples
This section covers examples of using the ZoKrates programming language.
- [A SNARK Powered RNG](./rng_tutorial.md)
- [Proving knowledge of a hash preimage](./sha256example.md)

View file

@ -2,7 +2,7 @@
## Prerequisites
Make sure you have followed the instructions in the [Getting Started](gettingstarted.md) chapter and are able to run the "Hello World" example described there.
Make sure you have followed the instructions in the [Getting Started](../gettingstarted.md) chapter and are able to run the "Hello World" example described there.
## Description of the problem
@ -25,24 +25,20 @@ The first step is for Alice and Bob to each come up with a preimage value and ca
There are many ways to calculate a hash, but here we use Zokrates.
1. Create this file under the name `get_hash.zok`:
```javascript
import "hashes/sha256/512bit" as sha256
def main(u32[16] hashMe) -> u32[8]:
u32[8] h = sha256(hashMe[0..8], hashMe[8..16])
return h
```zokrates
{{#include ../../zokrates_cli/examples/book/rng_tutorial/get_hash.zok}}
```
2. Compile the program to a form that is usable for zero knowledge proofs. This command writes
the binary to `get_hash`. You can see a textual representation, somewhat analogous to assembler
coming from a compiler, at `get_hash.ztf` if you remove the `--light` command line option.
coming from a compiler, at `get_hash.ztf` enabled by the `--ztf` command line option.
```
zokrates compile -i get_hash.zok -o get_hash --light
zokrates compile -i get_hash.zok -o get_hash --ztf
```
3. The input to the Zokrates program is sixteen 32 bit values, each in decimal. specify those values
to get a hash. For example, to calculate the hash of `0x00000000000000010000000200000003000000040000000500000006...`
use this command:
```
zokrates compute-witness --light -i get_hash -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
zokrates compute-witness --verbose -i get_hash -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
```
The result is:
```
@ -102,34 +98,14 @@ Finally, return `h` to the caller to display the hash.
The next step is to reveal a single bit.
1. Use this program, `reveal_bit.zok`:
```javascript
import "hashes/sha256/512bit" as sha256
import "utils/casts/u32_to_bits" as u32_to_bits
// Reveal a bit from a 512 bit value, and return it with the corresponding hash
// for that value.
//
// WARNING, once enough bits have been revealed it is possible to brute force
// the remaining preimage bits.
def main(private u32[16] preimage, field bitNum) -> (u32[8], bool):
// Convert the preimage to bits
bool[512] preimageBits = [false; 512]
for field i in 0..16 do
bool[32] val = u32_to_bits(preimage[i])
for field bit in 0..32 do
preimageBits[i*32+bit] = val[bit]
endfor
endfor
return sha256(preimage[0..8], preimage[8..16]), preimageBits[bitNum]
```zokrates
{{#include ../../zokrates_cli/examples/book/rng_tutorial/reveal_bit.zok}}
```
2. Compile and run as you did the previous program:
```bash
zokrates compile -i reveal_bit.zok -o reveal_bit --light
zokrates compute-witness --light -i reveal_bit -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 510
zokrates compile -i reveal_bit.zok -o reveal_bit
zokrates compute-witness --verbose -i reveal_bit -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 510
```
3. The output should be similar to:
```
@ -157,7 +133,7 @@ A Zokrates function can return multiple values. In this case, it returns the has
value of the bit being revealed.
```javascript
def main(private u32[16] preimage, field bitNum) -> (u32[8], bool):
def main(private u32[16] preimage, u32 bitNum) -> (u32[8], bool):
```
&nbsp;
@ -176,10 +152,10 @@ when it is declared.
&nbsp;
This is a [for loop](https://zokrates.github.io/language/control_flow.html#for-loops). For loops
have to have an index of type `field`, and their bounds need to be known at compile time.
have to have an index of type `u32`, and their bounds need to be known at compile time.
In this case, we go over each of the sixteen 32 bit words.
```javascript
for field i in 0..16 do
for u32 i in 0..16 do
```
The function we imported, `u32_to_bits`, converts a `u32` value to an array of bits.
@ -193,7 +169,7 @@ The function we imported, `u32_to_bits`, converts a `u32` value to an array of b
The inner loop copies the bits from `val` to `preimageBits`, the bit array for the preimage.
```javascript
for field bit in 0..32 do
for u32 bit in 0..32 do
preimageBits[i*32+bit] = val[bit]
endfor
endfor
@ -229,8 +205,8 @@ Proofs give us.
2. Compile `reveal_bit.zok` and create the proving and verification keys.
```
zokrates compile -i reveal_bit.zok -o reveal_bit --light
zokrates setup -i reveal_bit --light
zokrates compile -i reveal_bit.zok -o reveal_bit
zokrates setup -i reveal_bit
```
3. Copy the file `proving.key` to Alice's directory.
@ -238,13 +214,13 @@ Proofs give us.
4. Alice should compile `reveal_bit.zok` independently to make sure it doesn't disclose information she wants to keep secret.
```
zokrates compile -i reveal_bit.zok -o reveal_bit --light
zokrates compile -i reveal_bit.zok -o reveal_bit
```
5. Next, Alice creates the `witness` file with the values of all the parameters in the program. Using this `witness`,
Bob's `proving.key`, and the compiled program she generates the actual proof.
```
zokrates compute-witness -i reveal_bit -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 510 --light
zokrates compute-witness -i reveal_bit -a 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 510
zokrates generate-proof -i reveal_bit
```

View file

@ -7,7 +7,7 @@ In particular, we'll show how ZoKrates and the Ethereum blockchain can be used t
## Pre-requisites
Make sure you have followed the instructions in the [Getting Started](gettingstarted.md) chapter and are able to run the "Hello World" example described there.
Make sure you have followed the instructions in the [Getting Started](../gettingstarted.md) chapter and are able to run the "Hello World" example described there.
## Computing a Hash using ZoKrates

View file

@ -60,6 +60,8 @@ zokrates compute-witness -a 337 113569
zokrates generate-proof
# export a solidity verifier
zokrates export-verifier
# or verify natively
zokrates verify
```
The CLI commands are explained in more detail in the [CLI reference](toolbox/cli.md).

View file

@ -0,0 +1,17 @@
## Constants
Constants must be globally defined outside all other scopes by using a `const` keyword. Constants can be set only to a constant expression.
```zokrates
{{#include ../../../zokrates_cli/examples/book/constant_definition.zok}}
```
The value of a constant can't be changed through reassignment, and it can't be redeclared.
Constants must be explicitly typed. One can reference other constants inside the expression, as long as the referenced constant is already defined.
```zokrates
{{#include ../../../zokrates_cli/examples/book/constant_reference.zok}}
```
The naming convention for constants are similar to that of variables. All characters in a constant name are usually in uppercase.

View file

@ -26,6 +26,23 @@ An if-expression allows you to branch your code depending on a boolean condition
{{#include ../../../zokrates_cli/examples/book/if_else.zok}}
```
There are two important caveats when it comes to conditional expressions. Before we go into them, let's define two concepts:
- for an execution of the program, *an executed branch* is a branch which has to be paid for when executing the program, generating proofs, etc.
- for an execution of the program, *a logically executed branch* is a branch which is "chosen" by the condition of an if-expression. This is the more intuitive notion of execution, and there is only one for each if-expression.
Now the two caveats:
- **Both branches are always executed**. No short-circuiting happens based on the value of the condition. Therefore, the complexity of a program in terms of the number of constraints it compiles down to is the *sum* of the cost of all branches.
```zokrates
{{#include ../../../zokrates_cli/examples/book/if_else_expensive.zok}}
```
- **An unsatisfied constraint inside any branch will make the whole execution fail, even if this branch is not logically executed**. Also, the compiler itself inserts assertions which can fail. This can lead to unexpected results:
```zokrates
{{#include ../../../zokrates_cli/examples/book/if_else_panic.zok}}
```
The experimental flag `--branch-isolation` can be activated in the CLI in order to restrict any unsatisfied constraint to make the execution fail only if it is in a logically executed branch. This way, the execution of the program above will always succeed.
>The reason for these caveats is that the program is compiled down to an arithmetic circuit. This construct does not support jumping to a branch depending on a condition as you could do on traditional architectures. Instead, all branches are inlined as if they were printed on a circuit board. The `branch-isolation` feature comes with overhead for each assertion in each branch, and this overhead compounds when deeply nesting conditionals.
### For loops
For loops are available with the following syntax:

View file

@ -44,7 +44,7 @@ from "./path/to/my/module" import main as module
Note that this legacy method is likely to become deprecated, so it is recommended to use the preferred way instead.
### Symbols
Two types of symbols can be imported
Three types of symbols can be imported
#### Functions
Functions are imported by name. If many functions have the same name but different signatures, all of them get imported, and which one to use in a particular call is inferred.
@ -52,6 +52,9 @@ Functions are imported by name. If many functions have the same name but differe
#### User-defined types
User-defined types declared with the `struct` keyword are imported by name.
#### Constants
Constants declared with the `const` keyword are imported by name.
### Relative Imports
You can import a resource in the same folder directly, like this:
@ -66,4 +69,4 @@ from "../../../mycode" import foo
### Absolute Imports
Absolute imports don't start with `./` or `../` in the path and are used to import components from the ZoKrates standard library. Please check the according [section](./stdlib.html) for more details.
Absolute imports don't start with `./` or `../` in the path and are used to import components from the ZoKrates standard library. Please check the according [section](../toolbox/stdlib.md) for more details.

View file

@ -22,4 +22,4 @@ The following table lists the precedence and associativity of all operators. Ope
[^2]: The right operand must be a compile time constant of type `u32`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`
[^3]: Both operands are asserted to be strictly lower than the biggest power of 2 lower than `p/2`, unless one of them can be determined to be a compile-time constant

View file

@ -8,7 +8,7 @@ ZoKrates currently exposes two primitive types and two complex types:
This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers.
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/toolbox/proving_schemes.html#alt_bn128) curve supported by Ethereum.
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](../toolbox/proving_schemes.md#curves) curve supported by Ethereum.
While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have:
@ -48,7 +48,7 @@ ZoKrates provides two complex types: arrays and structs.
### Arrays
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](/language/generics.html)
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](../language/generics.md)
Arrays can contain elements of any type and have arbitrary dimensions.
The following example code shows examples of how to use arrays:

View file

@ -26,7 +26,7 @@ In this example, the ABI specification is:
"members":[
{
"name":"a",
"type":"field"
"type":"u8"
},
{
"name":"b",
@ -75,7 +75,7 @@ When executing a program, arguments can be passed as a JSON object of the follow
```json
[
{
"a":"42",
"a":"0x2a",
"b":{
"a":"42"
}
@ -89,5 +89,6 @@ When executing a program, arguments can be passed as a JSON object of the follow
```
Note the following:
- Field elements are passed as JSON strings in order to support arbitrary large numbers.
- Field elements are passed as JSON strings in order to support arbitrary large numbers
- Unsigned integers are passed as JSON strings containing their hexadecimal representation
- Structs are passed as JSON objects, ignoring the struct name

View file

@ -1,6 +1,6 @@
## Standard library
ZoKrates comes with a number of reusable components in the form of a Standard Library. In order to import it as described in the [imports](./imports.html) section, the `$ZOKRATES_HOME` environment variable must be set to the `stdlib` folder.
ZoKrates comes with a number of reusable components in the form of a Standard Library. In order to import it as described in the [imports](../language/imports.md) section, the `$ZOKRATES_STDLIB` environment variable must be set to the `stdlib` folder.
The full ZoKrates Standard Library can be found [here](https://github.com/Zokrates/ZoKrates/tree/latest/zokrates_stdlib/stdlib).

View file

@ -90,7 +90,19 @@ const artifacts = zokratesProvider.compile(source, options);
**Note:** The `resolveCallback` function is used to resolve dependencies.
This callback receives the current module location and the import location of the module which is being imported.
The callback must synchronously return either an error, `null` or a valid `ResolverResult` object like shown in the example above.
The callback must synchronously return either an error, `null` or a valid `ResolverResult` object like shown in the example above.
A simple file system resolver for a node environment can be implemented as follows:
```js
const fs = require("fs");
const path = require("path");
const fileSystemResolver = (from, to) => {
const location = path.resolve(path.dirname(path.resolve(from)), to);
const source = fs.readFileSync(location).toString();
return { source, location };
};
```
##### computeWitness(artifacts, args)
Computes a valid assignment of the variables, which include the results of the computation.

View file

@ -1,6 +1,6 @@
[package]
name = "zokrates_cli"
version = "0.7.0"
version = "0.7.4"
authors = ["Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>", "Dennis Kuhnert <mail@kyroy.com>", "Thibaut Schaeffer <thibaut@schaeff.fr>"]
repository = "https://github.com/JacobEberhardt/ZoKrates.git"
edition = "2018"

View file

@ -0,0 +1,5 @@
def identity<N>(field[N][N] t) -> field[N][N]:
return t
def main() -> field[1][1]:
return identity([[0]; 1])

View file

@ -0,0 +1,11 @@
def foo() -> u32:
return 0
def bar() -> (u32, u32):
return 0, 0
def main(u32[1] a, u32 b):
a[0] = foo()
a[0], b = bar()
return

View file

@ -3,7 +3,7 @@ struct Bar {
}
struct Foo {
field a
u8 a
Bar b
}

View file

@ -0,0 +1,4 @@
const field PRIME = 31
def main() -> field:
return PRIME

View file

@ -0,0 +1,5 @@
const field ONE = 1
const field TWO = ONE + ONE
def main() -> field:
return TWO

View file

@ -0,0 +1,12 @@
def cheap(field x) -> field:
return x + 1
def expensive(field x) -> field:
return x**1000
def main(field x) -> field:
return if x == 1 then\
cheap(x)\// executed
else\
expensive(x)\// also executed
fi

View file

@ -0,0 +1,6 @@
def main(field x) -> field:
return if x == 0 then\
0\
else\
1/x\// executed even for x := 0, which leads to the execution failing
fi

View file

@ -0,0 +1,5 @@
import "hashes/sha256/512bit" as sha256
def main(u32[16] hashMe) -> u32[8]:
u32[8] h = sha256(hashMe[0..8], hashMe[8..16])
return h

View file

@ -0,0 +1,20 @@
import "hashes/sha256/512bit" as sha256
import "utils/casts/u32_to_bits" as u32_to_bits
// Reveal a bit from a 512 bit value, and return it with the corresponding hash
// for that value.
//
// WARNING, once enough bits have been revealed it is possible to brute force
// the remaining preimage bits.
def main(private u32[16] preimage, u32 bitNum) -> (u32[8], bool):
// Convert the preimage to bits
bool[512] preimageBits = [false; 512]
for u32 i in 0..16 do
bool[32] val = u32_to_bits(preimage[i])
for u32 bit in 0..32 do
preimageBits[i*32+bit] = val[bit]
endfor
endfor
return sha256(preimage[0..8], preimage[8..16]), preimageBits[bitNum]

View file

@ -0,0 +1,4 @@
const field SIZE = 2
def main(field[SIZE] n):
return

View file

@ -0,0 +1,4 @@
const u8 SIZE = 0x02
def main(field[SIZE] n):
return

View file

@ -0,0 +1,5 @@
const field a = 1
def main() -> field:
a = 2 // not allowed
return a

View file

@ -0,0 +1,3 @@
def main():
field[2] a[2] = [1, 2] // only variables can be declared in such a statement, declaring `a[2]` is invalid
return

View file

@ -0,0 +1,7 @@
const u32 N = 42
def foo<N>(field[N] a) -> bool:
return true
def main():
return

View file

@ -0,0 +1,4 @@
def main():
for u32 i in 0..-1 do
endfor
return

View file

@ -0,0 +1,17 @@
struct Foo {
field[2] values
}
struct Bar {
Foo foo
field bar
}
def main():
Bar s = Bar {
foo: Foo { values: [1] },
bar: 0,
}
field b = s.bar
return

View file

@ -0,0 +1,17 @@
struct Foo {
u8[2] values
}
struct Bar {
Foo foo
u8 bar
}
def main():
Bar s = Bar {
foo: Foo { values: [1] }, // notice the size mismatch here
bar: 0,
}
u8 b = s.bar
return

View file

@ -0,0 +1,7 @@
const u32 N = 1
def foo(bool[N] arr) -> bool:
return true
def main(bool[N] arr):
assert(foo(arr))
return

View file

@ -4,7 +4,7 @@ def bound(field x) -> u32:
def main(field a) -> field:
field x = 7
x = x + 1
for u32 i in 0..bound(x) do
for u32 i in 0..bound(x) + bound(x + 1) do
// x = x + a
x = x + a
endfor

View file

@ -1,10 +1,10 @@
def const() -> field:
def constant() -> field:
return 123123
def add(field a,field b) -> field:
a=const()
return a+b
def add(field a, field b) -> field:
a = constant()
return a + b
def main(field a,field b) -> field:
field c = add(a, b+const())
return const()
def main(field a, field b) -> field:
field c = add(a, b + constant())
return constant()

View file

@ -1,5 +1,7 @@
struct Bar {
}
struct Bar {}
const field ONE = 1
const field BAR = 21 * ONE
def main() -> field:
return 21
return BAR

View file

@ -1,5 +1,6 @@
struct Baz {
}
struct Baz {}
const field BAZ = 123
def main() -> field:
return 123
return BAZ

View file

@ -1,9 +1,10 @@
from "./baz" import Baz
import "./baz"
from "./baz" import main as my_function
import "./baz"
const field FOO = 144
def main() -> field:
field a = my_function()
Baz b = Baz {}
return baz()
Baz b = Baz {}
assert(baz() == my_function())
return FOO

View file

@ -1,11 +0,0 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
import "./foo"
import "./bar"
def main() -> field:
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return foo() + bar()

View file

@ -0,0 +1,6 @@
from "./foo" import FOO
from "./bar" import BAR
from "./baz" import BAZ
def main() -> bool:
return FOO == BAR + BAZ

View file

@ -0,0 +1,6 @@
import "./foo"
import "./bar"
import "./baz"
def main() -> bool:
return foo() == bar() + baz()

View file

@ -0,0 +1,8 @@
from "./bar" import Bar as MyBar
from "./bar" import Bar
def main():
MyBar my_bar = MyBar {}
Bar bar = Bar {}
assert(my_bar == bar)
return

View file

@ -1,4 +1,8 @@
import "./foo" as d
from "./bar" import main as bar
from "./baz" import BAZ as baz
import "./foo" as f
def main() -> field:
return d()
field foo = f()
assert(foo == bar() + baz)
return foo

View file

@ -9,4 +9,4 @@ def main(field a) -> bool:
// maxvalue = 2**252 - 1
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < (maxvalue + 1)
return a < (maxvalue + 1)

View file

@ -4,4 +4,4 @@
def main(field a) -> bool:
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
// we added a = 0 to prevent the condition to be evaluated at compile time
return 0 < p
return a < p

View file

@ -0,0 +1,11 @@
struct Foo {
field a
}
struct Bar {
Foo foo
}
def main(Bar b):
field a = b.foo.a
return

View file

@ -5,7 +5,7 @@ use std::convert::TryFrom;
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::{Path, PathBuf};
use zokrates_core::compile::{check, CompileError};
use zokrates_core::compile::{check, CompileConfig, CompileError};
use zokrates_field::{Bls12_377Field, Bls12_381Field, Bn128Field, Bw6_761Field, Field};
use zokrates_fs_resolver::FileSystemResolver;
@ -41,6 +41,11 @@ pub fn subcommand() -> App<'static, 'static> {
.possible_values(constants::CURVES)
.default_value(constants::BN128),
)
.arg(Arg::with_name("isolate-branches")
.long("isolate-branches")
.help("Isolate the execution of branches: a panic in a branch only makes the program panic if this branch is being logically executed")
.required(false)
)
}
pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> {
@ -84,8 +89,11 @@ fn cli_check<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
)),
}?;
let config =
CompileConfig::default().isolate_branches(sub_matches.is_present("isolate-branches"));
let resolver = FileSystemResolver::with_stdlib_root(stdlib_path);
let _ = check::<T, _>(source, path, Some(&resolver)).map_err(|e| {
let _ = check::<T, _>(source, path, Some(&resolver), &config).map_err(|e| {
format!(
"Check failed:\n\n{}",
e.0.iter()

View file

@ -56,6 +56,10 @@ pub fn subcommand() -> App<'static, 'static> {
.long("allow-unconstrained-variables")
.help("Allow unconstrained variables by inserting dummy constraints")
.required(false)
).arg(Arg::with_name("isolate-branches")
.long("isolate-branches")
.help("Isolate the execution of branches: a panic in a branch only makes the program panic if this branch is being logically executed")
.required(false)
).arg(Arg::with_name("ztf")
.long("ztf")
.help("Write human readable output (ztf)")
@ -122,9 +126,9 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
)),
}?;
let config = CompileConfig {
allow_unconstrained_variables: sub_matches.is_present("allow-unconstrained-variables"),
};
let config = CompileConfig::default()
.allow_unconstrained_variables(sub_matches.is_present("allow-unconstrained-variables"))
.isolate_branches(sub_matches.is_present("isolate-branches"));
let resolver = FileSystemResolver::with_stdlib_root(stdlib_path);
let artifacts: CompilationArtifacts<T> = compile(source, path, Some(&resolver), &config)

View file

@ -1,6 +1,6 @@
[package]
name = "zokrates_core"
version = "0.6.0"
version = "0.6.4"
edition = "2018"
authors = ["Jacob Eberhardt <jacob.eberhardt@tu-berlin.de>", "Dennis Kuhnert <mail@kyroy.com>"]
repository = "https://github.com/JacobEberhardt/ZoKrates"

View file

@ -1,60 +1,71 @@
use crate::absy;
use crate::imports;
use num_bigint::BigUint;
use std::path::Path;
use zokrates_pest_ast as pest;
impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
fn from(prog: pest::File<'ast>) -> absy::Module<'ast> {
absy::Module::with_symbols(
prog.structs
.into_iter()
.map(absy::SymbolDeclarationNode::from)
.chain(
prog.functions
.into_iter()
.map(absy::SymbolDeclarationNode::from),
),
)
.imports(
prog.imports
.into_iter()
.map(absy::ImportDirective::from)
.flatten(),
)
fn from(file: pest::File<'ast>) -> absy::Module<'ast> {
absy::Module::with_symbols(file.declarations.into_iter().flat_map(|d| match d {
pest::SymbolDeclaration::Import(i) => import_directive_to_symbol_vec(i),
pest::SymbolDeclaration::Constant(c) => vec![c.into()],
pest::SymbolDeclaration::Struct(s) => vec![s.into()],
pest::SymbolDeclaration::Function(f) => vec![f.into()],
}))
}
}
impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportDirective<'ast> {
fn from(import: pest::ImportDirective<'ast>) -> absy::ImportDirective<'ast> {
use crate::absy::NodeValue;
fn import_directive_to_symbol_vec(
import: pest::ImportDirective,
) -> Vec<absy::SymbolDeclarationNode> {
use crate::absy::NodeValue;
match import {
pest::ImportDirective::Main(import) => absy::ImportDirective::Main(
imports::Import::new(None, std::path::Path::new(import.source.span.as_str()))
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span),
),
pest::ImportDirective::From(import) => absy::ImportDirective::From(
import
.symbols
.iter()
.map(|symbol| {
imports::Import::new(
Some(symbol.symbol.span.as_str()),
std::path::Path::new(import.source.span.as_str()),
)
.alias(
symbol
.alias
.as_ref()
.map(|a| a.span.as_str())
.or_else(|| Some(symbol.symbol.span.as_str())),
)
.span(symbol.span.clone())
})
.collect(),
),
match import {
pest::ImportDirective::Main(import) => {
let span = import.span;
let source = Path::new(import.source.span.as_str());
let id = "main";
let alias = import.alias.map(|a| a.span.as_str());
let import = absy::CanonicalImport {
source,
id: absy::SymbolIdentifier::from(id).alias(alias),
}
.span(span.clone());
vec![absy::SymbolDeclaration {
id: alias.unwrap_or(id),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Import(import)),
}
.span(span.clone())]
}
pest::ImportDirective::From(import) => {
let span = import.span;
let source = Path::new(import.source.span.as_str());
import
.symbols
.into_iter()
.map(|symbol| {
let alias = symbol
.alias
.as_ref()
.map(|a| a.span.as_str())
.unwrap_or_else(|| symbol.id.span.as_str());
let import = absy::CanonicalImport {
source,
id: absy::SymbolIdentifier::from(symbol.id.span.as_str())
.alias(Some(alias)),
}
.span(span.clone());
absy::SymbolDeclaration {
id: alias,
symbol: absy::Symbol::Here(absy::SymbolDefinition::Import(import)),
}
.span(span.clone())
})
.collect()
}
}
}
@ -78,7 +89,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::HereType(ty),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Struct(ty)),
}
.span(span)
}
@ -98,8 +109,29 @@ impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'as
}
}
impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::Function<'ast>) -> absy::SymbolDeclarationNode<'ast> {
impl<'ast> From<pest::ConstantDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(definition: pest::ConstantDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
let span = definition.span;
let id = definition.id.span.as_str();
let ty = absy::ConstantDefinition {
ty: definition.ty.into(),
expression: definition.expression.into(),
}
.span(span.clone());
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::Here(absy::SymbolDefinition::Constant(ty)),
}
.span(span)
}
}
impl<'ast> From<pest::FunctionDefinition<'ast>> for absy::SymbolDeclarationNode<'ast> {
fn from(function: pest::FunctionDefinition<'ast>) -> absy::SymbolDeclarationNode<'ast> {
use crate::absy::NodeValue;
let span = function.span;
@ -148,7 +180,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
absy::SymbolDeclaration {
id,
symbol: absy::Symbol::HereFunction(function),
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(function)),
}
.span(span)
}
@ -207,54 +239,73 @@ fn statements_from_definition(definition: pest::DefinitionStatement) -> Vec<absy
let e: absy::ExpressionNode = absy::ExpressionNode::from(definition.expression);
let s = match e.value {
absy::Expression::FunctionCall(..) => absy::Statement::MultipleDefinition(
vec![absy::AssigneeNode::from(a.a.clone())],
e,
),
_ => absy::Statement::Definition(absy::AssigneeNode::from(a.a.clone()), e),
};
match a.ty {
Some(ty) => {
assert_eq!(a.a.accesses.len(), 0);
match a {
pest::TypedIdentifierOrAssignee::TypedIdentifier(i) => {
let declaration = absy::Statement::Declaration(
absy::Variable::new(
a.a.id.span.as_str(),
absy::UnresolvedTypeNode::from(ty),
i.identifier.span.as_str(),
absy::UnresolvedTypeNode::from(i.ty),
)
.span(a.a.id.span.clone()),
.span(i.identifier.span.clone()),
)
.span(definition.span.clone());
let s = match e.value {
absy::Expression::FunctionCall(..) => absy::Statement::MultipleDefinition(
vec![absy::AssigneeNode::from(i.identifier.clone())],
e,
),
_ => absy::Statement::Definition(
absy::AssigneeNode::from(i.identifier.clone()),
e,
),
};
vec![declaration, s.span(definition.span)]
}
None => {
// Assignment
pest::TypedIdentifierOrAssignee::Assignee(a) => {
let s = match e.value {
absy::Expression::FunctionCall(..) => absy::Statement::MultipleDefinition(
vec![absy::AssigneeNode::from(a)],
e,
),
_ => absy::Statement::Definition(absy::AssigneeNode::from(a), e),
};
vec![s.span(definition.span)]
}
}
}
_ => {
// Multidefinition
let declarations = lhs.clone().into_iter().filter(|i| i.ty.is_some()).map(|a| {
let ty = a.ty;
let a = a.a;
let declarations = lhs.clone().into_iter().filter_map(|i| match i {
pest::TypedIdentifierOrAssignee::TypedIdentifier(i) => {
let ty = i.ty;
let id = i.identifier;
assert_eq!(a.accesses.len(), 0);
absy::Statement::Declaration(
absy::Variable::new(
a.id.span.as_str(),
absy::UnresolvedTypeNode::from(ty.unwrap()),
Some(
absy::Statement::Declaration(
absy::Variable::new(
id.span.as_str(),
absy::UnresolvedTypeNode::from(ty),
)
.span(id.span),
)
.span(i.span),
)
.span(a.id.span),
)
.span(a.span)
}
_ => None,
});
let lhs = lhs
.into_iter()
.map(|i| absy::Assignee::Identifier(i.a.id.span.as_str()).span(i.a.id.span))
.map(|i| match i {
pest::TypedIdentifierOrAssignee::TypedIdentifier(i) => {
absy::Assignee::Identifier(i.identifier.span.as_str())
.span(i.identifier.span)
}
pest::TypedIdentifierOrAssignee::Assignee(a) => absy::AssigneeNode::from(a),
})
.collect();
let multi_def = absy::Statement::MultipleDefinition(
@ -754,7 +805,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -771,10 +822,9 @@ mod tests {
.outputs(vec![UnresolvedType::FieldElement.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
}
@ -786,7 +836,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -801,10 +851,9 @@ mod tests {
.outputs(vec![UnresolvedType::Boolean.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
}
@ -817,7 +866,7 @@ mod tests {
let expected: absy::Module = absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: &source[4..8],
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![
absy::Parameter::private(
@ -854,10 +903,9 @@ mod tests {
.outputs(vec![UnresolvedType::FieldElement.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
};
assert_eq!(absy::Module::from(ast), expected);
@ -871,7 +919,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![absy::Parameter::private(
absy::Variable::new("a", ty.clone().mock()).into(),
@ -887,10 +935,9 @@ mod tests {
signature: UnresolvedSignature::new().inputs(vec![ty.mock()]),
}
.into(),
),
)),
}
.into()],
imports: vec![],
}
}
@ -945,7 +992,7 @@ mod tests {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
symbol: absy::Symbol::HereFunction(
symbol: absy::Symbol::Here(absy::SymbolDefinition::Function(
absy::Function {
arguments: vec![],
statements: vec![absy::Statement::Return(
@ -958,10 +1005,9 @@ mod tests {
signature: UnresolvedSignature::new(),
}
.into(),
),
)),
}
.into()],
imports: vec![],
}
}
@ -1069,18 +1115,14 @@ mod tests {
// A `Definition` is generated and no `Declaration`s
let definition = pest::DefinitionStatement {
lhs: vec![pest::OptionallyTypedAssignee {
ty: None,
a: pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
accesses: vec![],
lhs: vec![pest::TypedIdentifierOrAssignee::Assignee(pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
accesses: vec![],
span: span.clone(),
}],
})],
expression: pest::Expression::Literal(pest::LiteralExpression::DecimalLiteral(
pest::DecimalLiteralExpression {
value: pest::DecimalNumber {
@ -1107,18 +1149,14 @@ mod tests {
// A MultiDef is generated
let definition = pest::DefinitionStatement {
lhs: vec![pest::OptionallyTypedAssignee {
ty: None,
a: pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
accesses: vec![],
lhs: vec![pest::TypedIdentifierOrAssignee::Assignee(pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
accesses: vec![],
span: span.clone(),
}],
})],
expression: pest::Expression::Postfix(pest::PostfixExpression {
id: pest::IdentifierExpression {
value: String::from("foo"),
@ -1153,32 +1191,24 @@ mod tests {
let definition = pest::DefinitionStatement {
lhs: vec![
pest::OptionallyTypedAssignee {
ty: Some(pest::Type::Basic(pest::BasicType::Field(pest::FieldType {
pest::TypedIdentifierOrAssignee::TypedIdentifier(pest::TypedIdentifier {
ty: pest::Type::Basic(pest::BasicType::Field(pest::FieldType {
span: span.clone(),
}))),
a: pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
accesses: vec![],
})),
identifier: pest::IdentifierExpression {
value: String::from("a"),
span: span.clone(),
},
span: span.clone(),
},
pest::OptionallyTypedAssignee {
ty: None,
a: pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("b"),
span: span.clone(),
},
accesses: vec![],
}),
pest::TypedIdentifierOrAssignee::Assignee(pest::Assignee {
id: pest::IdentifierExpression {
value: String::from("b"),
span: span.clone(),
},
accesses: vec![],
span: span.clone(),
},
}),
],
expression: pest::Expression::Postfix(pest::PostfixExpression {
id: pest::IdentifierExpression {

View file

@ -18,8 +18,6 @@ pub use crate::absy::variable::{Variable, VariableNode};
use crate::embed::FlatEmbed;
use std::path::{Path, PathBuf};
use crate::imports::ImportDirective;
use crate::imports::ImportNode;
use std::fmt;
use num_bigint::BigUint;
@ -44,38 +42,134 @@ pub struct Program<'ast> {
pub main: OwnedModuleId,
}
/// A declaration of a `FunctionSymbol`, be it from an import or a function definition
#[derive(PartialEq, Clone, Debug)]
#[derive(Debug, PartialEq, Clone)]
pub struct SymbolIdentifier<'ast> {
pub id: Identifier<'ast>,
pub alias: Option<Identifier<'ast>>,
}
impl<'ast> From<Identifier<'ast>> for SymbolIdentifier<'ast> {
fn from(id: &'ast str) -> Self {
SymbolIdentifier { id, alias: None }
}
}
impl<'ast> SymbolIdentifier<'ast> {
pub fn alias(mut self, alias: Option<Identifier<'ast>>) -> Self {
self.alias = alias;
self
}
pub fn get_alias(&self) -> Identifier<'ast> {
self.alias.unwrap_or(self.id)
}
}
impl<'ast> fmt::Display for SymbolIdentifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}{}",
self.id,
self.alias.map(|a| format!(" as {}", a)).unwrap_or_default()
)
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct CanonicalImport<'ast> {
pub source: &'ast Path,
pub id: SymbolIdentifier<'ast>,
}
pub type CanonicalImportNode<'ast> = Node<CanonicalImport<'ast>>;
impl<'ast> fmt::Display for CanonicalImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "from \"{}\" import {}", self.source.display(), self.id)
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {
pub module_id: OwnedModuleId,
pub symbol_id: Identifier<'ast>,
}
pub type SymbolImportNode<'ast> = Node<SymbolImport<'ast>>;
impl<'ast> SymbolImport<'ast> {
pub fn with_id_in_module<S: Into<Identifier<'ast>>, U: Into<OwnedModuleId>>(
symbol_id: S,
module_id: U,
) -> Self {
SymbolImport {
symbol_id: symbol_id.into(),
module_id: module_id.into(),
}
}
}
impl<'ast> fmt::Display for SymbolImport<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"from \"{}\" import {}",
self.module_id.display(),
self.symbol_id
)
}
}
/// A declaration of a symbol
#[derive(Debug, PartialEq, Clone)]
pub struct SymbolDeclaration<'ast> {
pub id: Identifier<'ast>,
pub symbol: Symbol<'ast>,
}
#[derive(PartialEq, Clone)]
#[allow(clippy::large_enum_variant)]
#[derive(Debug, PartialEq, Clone)]
pub enum SymbolDefinition<'ast> {
Import(CanonicalImportNode<'ast>),
Struct(StructDefinitionNode<'ast>),
Constant(ConstantDefinitionNode<'ast>),
Function(FunctionNode<'ast>),
}
#[derive(Debug, PartialEq, Clone)]
pub enum Symbol<'ast> {
HereType(StructDefinitionNode<'ast>),
HereFunction(FunctionNode<'ast>),
Here(SymbolDefinition<'ast>),
There(SymbolImportNode<'ast>),
Flat(FlatEmbed),
}
impl<'ast> fmt::Debug for Symbol<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Symbol::HereType(t) => write!(f, "HereType({:?})", t),
Symbol::HereFunction(fun) => write!(f, "HereFunction({:?})", fun),
Symbol::There(t) => write!(f, "There({:?})", t),
Symbol::Flat(flat) => write!(f, "Flat({:?})", flat),
}
}
}
impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.symbol {
Symbol::HereType(ref t) => write!(f, "struct {} {}", self.id, t),
Symbol::HereFunction(ref fun) => write!(f, "def {}{}", self.id, fun),
Symbol::There(ref import) => write!(f, "import {} as {}", import, self.id),
match &self.symbol {
Symbol::Here(ref symbol) => match symbol {
SymbolDefinition::Import(ref i) => write!(
f,
"from \"{}\" import {}",
i.value.source.display(),
i.value.id
),
SymbolDefinition::Struct(ref t) => write!(f, "struct {} {}", self.id, t),
SymbolDefinition::Constant(ref c) => write!(
f,
"const {} {} = {}",
c.value.ty, self.id, c.value.expression
),
SymbolDefinition::Function(ref func) => {
write!(f, "def {}{}", self.id, func)
}
},
Symbol::There(ref i) => write!(
f,
"from \"{}\" import {} as {}",
i.value.module_id.display(),
i.value.symbol_id,
self.id
),
Symbol::Flat(ref flat_fun) => {
write!(f, "def {}{}:\n\t// hidden", self.id, flat_fun.signature())
}
@ -86,25 +180,18 @@ impl<'ast> fmt::Display for SymbolDeclaration<'ast> {
pub type SymbolDeclarationNode<'ast> = Node<SymbolDeclaration<'ast>>;
/// A module as a collection of `FunctionDeclaration`s
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Module<'ast> {
/// Symbols of the module
pub symbols: Declarations<'ast>,
pub imports: Vec<ImportNode<'ast>>, // we still use `imports` as they are not directly converted into `FunctionDeclaration`s after the importer is done, `imports` is empty
}
impl<'ast> Module<'ast> {
pub fn with_symbols<I: IntoIterator<Item = SymbolDeclarationNode<'ast>>>(i: I) -> Self {
Module {
symbols: i.into_iter().collect(),
imports: vec![],
}
}
pub fn imports<I: IntoIterator<Item = ImportNode<'ast>>>(mut self, i: I) -> Self {
self.imports = i.into_iter().collect();
self
}
}
pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
@ -146,82 +233,35 @@ impl<'ast> fmt::Display for StructDefinitionField<'ast> {
type StructDefinitionFieldNode<'ast> = Node<StructDefinitionField<'ast>>;
/// An import
#[derive(Debug, Clone, PartialEq)]
pub struct SymbolImport<'ast> {
/// the id of the symbol in the target module. Note: there may be many candidates as imports statements do not specify the signature. In that case they must all be functions however.
pub symbol_id: Identifier<'ast>,
/// the id of the module to import from
pub module_id: OwnedModuleId,
pub struct ConstantDefinition<'ast> {
pub ty: UnresolvedTypeNode<'ast>,
pub expression: ExpressionNode<'ast>,
}
type SymbolImportNode<'ast> = Node<SymbolImport<'ast>>;
pub type ConstantDefinitionNode<'ast> = Node<ConstantDefinition<'ast>>;
impl<'ast> SymbolImport<'ast> {
pub fn with_id_in_module<S: Into<Identifier<'ast>>, U: Into<OwnedModuleId>>(
symbol_id: S,
module_id: U,
) -> Self {
SymbolImport {
symbol_id: symbol_id.into(),
module_id: module_id.into(),
}
}
}
impl<'ast> fmt::Display for SymbolImport<'ast> {
impl<'ast> fmt::Display for ConstantDefinition<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{} from {}",
self.symbol_id,
self.module_id.display().to_string()
)
write!(f, "const {}({})", self.ty, self.expression)
}
}
impl<'ast> fmt::Display for Module<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut res = vec![];
res.extend(
self.imports
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>(),
);
res.extend(
self.symbols
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>(),
);
let res = self
.symbols
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<_>>();
write!(f, "{}", res.join("\n"))
}
}
impl<'ast> fmt::Debug for Module<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"module(\n\timports:\n\t\t{}\n\tsymbols:\n\t\t{}\n)",
self.imports
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t"),
self.symbols
.iter()
.map(|x| format!("{:?}", x))
.collect::<Vec<_>>()
.join("\n\t\t")
)
}
}
pub type ConstantGenericNode<'ast> = Node<Identifier<'ast>>;
/// A function defined locally
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Function<'ast> {
/// Arguments of the function
pub arguments: Vec<ParameterNode<'ast>>,
@ -265,23 +305,8 @@ impl<'ast> fmt::Display for Function<'ast> {
}
}
impl<'ast> fmt::Debug for Function<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Function(arguments: {:?}, ...):\n{}",
self.arguments,
self.statements
.iter()
.map(|x| format!("\t{:?}", x))
.collect::<Vec<_>>()
.join("\n")
)
}
}
/// Something that we can assign to
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Assignee<'ast> {
Identifier(Identifier<'ast>),
Select(Box<AssigneeNode<'ast>>, Box<RangeOrExpression<'ast>>),
@ -290,16 +315,6 @@ pub enum Assignee<'ast> {
pub type AssigneeNode<'ast> = Node<Assignee<'ast>>;
impl<'ast> fmt::Debug for Assignee<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Assignee::Identifier(ref s) => write!(f, "Identifier({:?})", s),
Assignee::Select(ref a, ref e) => write!(f, "Select({:?}[{:?}])", a, e),
Assignee::Member(ref s, ref m) => write!(f, "Member({:?}.{:?})", s, m),
}
}
}
impl<'ast> fmt::Display for Assignee<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -312,7 +327,7 @@ impl<'ast> fmt::Display for Assignee<'ast> {
/// A statement in a `Function`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Statement<'ast> {
Return(ExpressionListNode<'ast>),
Declaration(VariableNode<'ast>),
@ -356,31 +371,8 @@ impl<'ast> fmt::Display for Statement<'ast> {
}
}
impl<'ast> fmt::Debug for Statement<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Statement::Return(ref expr) => write!(f, "Return({:?})", expr),
Statement::Declaration(ref var) => write!(f, "Declaration({:?})", var),
Statement::Definition(ref lhs, ref rhs) => {
write!(f, "Definition({:?}, {:?})", lhs, rhs)
}
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?;
for l in list {
writeln!(f, "\t\t{:?}", l)?;
}
write!(f, "\tendfor")
}
Statement::MultipleDefinition(ref lhs, ref rhs) => {
write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs)
}
}
}
}
/// An element of an inline array, can be a spread `...a` or an expression `a`
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum SpreadOrExpression<'ast> {
Spread(SpreadNode<'ast>),
Expression(ExpressionNode<'ast>),
@ -401,17 +393,8 @@ impl<'ast> fmt::Display for SpreadOrExpression<'ast> {
}
}
impl<'ast> fmt::Debug for SpreadOrExpression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
SpreadOrExpression::Spread(ref s) => write!(f, "{:?}", s),
SpreadOrExpression::Expression(ref e) => write!(f, "{:?}", e),
}
}
}
/// The index in an array selector. Can be a range or an expression.
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum RangeOrExpression<'ast> {
Range(RangeNode<'ast>),
Expression(ExpressionNode<'ast>),
@ -426,13 +409,10 @@ impl<'ast> fmt::Display for RangeOrExpression<'ast> {
}
}
impl<'ast> fmt::Debug for RangeOrExpression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
RangeOrExpression::Range(ref s) => write!(f, "{:?}", s),
RangeOrExpression::Expression(ref e) => write!(f, "{:?}", e),
}
}
/// A spread
#[derive(Debug, Clone, PartialEq)]
pub struct Spread<'ast> {
pub expression: ExpressionNode<'ast>,
}
pub type SpreadNode<'ast> = Node<Spread<'ast>>;
@ -443,20 +423,8 @@ impl<'ast> fmt::Display for Spread<'ast> {
}
}
impl<'ast> fmt::Debug for Spread<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Spread({:?})", self.expression)
}
}
/// A spread
#[derive(Clone, PartialEq)]
pub struct Spread<'ast> {
pub expression: ExpressionNode<'ast>,
}
/// A range
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub struct Range<'ast> {
pub from: Option<ExpressionNode<'ast>>,
pub to: Option<ExpressionNode<'ast>>,
@ -481,14 +449,8 @@ impl<'ast> fmt::Display for Range<'ast> {
}
}
impl<'ast> fmt::Debug for Range<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Range({:?}, {:?})", self.from, self.to)
}
}
/// An expression
#[derive(Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq)]
pub enum Expression<'ast> {
IntConstant(BigUint),
FieldConstant(BigUint),
@ -625,73 +587,8 @@ impl<'ast> fmt::Display for Expression<'ast> {
}
}
impl<'ast> fmt::Debug for Expression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Expression::U8Constant(ref i) => write!(f, "U8({:x})", i),
Expression::U16Constant(ref i) => write!(f, "U16({:x})", i),
Expression::U32Constant(ref i) => write!(f, "U32({:x})", i),
Expression::U64Constant(ref i) => write!(f, "U64({:x})", i),
Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i),
Expression::IntConstant(ref i) => write!(f, "Int({:?})", i),
Expression::Identifier(ref var) => write!(f, "Ide({})", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
Expression::Mult(ref lhs, ref rhs) => write!(f, "Mult({:?}, {:?})", lhs, rhs),
Expression::Div(ref lhs, ref rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
Expression::Rem(ref lhs, ref rhs) => write!(f, "Rem({:?}, {:?})", lhs, rhs),
Expression::Pow(ref lhs, ref rhs) => write!(f, "Pow({:?}, {:?})", lhs, rhs),
Expression::Neg(ref e) => write!(f, "Neg({:?})", e),
Expression::Pos(ref e) => write!(f, "Pos({:?})", e),
Expression::BooleanConstant(b) => write!(f, "{}", b),
Expression::IfElse(ref condition, ref consequent, ref alternative) => write!(
f,
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
Expression::FunctionCall(ref g, ref i, ref p) => {
write!(f, "FunctionCall({:?}, {:?}, (", g, i)?;
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
Expression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs),
Expression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs),
Expression::Eq(ref lhs, ref rhs) => write!(f, "Eq({:?}, {:?})", lhs, rhs),
Expression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs),
Expression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs),
Expression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs),
Expression::Not(ref exp) => write!(f, "Not({:?})", exp),
Expression::InlineArray(ref exprs) => {
write!(f, "InlineArray([")?;
f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]")
}
Expression::ArrayInitializer(ref e, ref count) => {
write!(f, "ArrayInitializer({:?}, {:?})", e, count)
}
Expression::InlineStruct(ref id, ref members) => {
write!(f, "InlineStruct({:?}, [", id)?;
f.debug_list().entries(members.iter()).finish()?;
write!(f, "]")
}
Expression::Select(ref array, ref index) => {
write!(f, "Select({:?}, {:?})", array, index)
}
Expression::Member(ref struc, ref id) => write!(f, "Member({:?}, {:?})", struc, id),
Expression::Or(ref lhs, ref rhs) => write!(f, "Or({:?}, {:?})", lhs, rhs),
Expression::BitXor(ref lhs, ref rhs) => write!(f, "BitXor({:?}, {:?})", lhs, rhs),
Expression::BitAnd(ref lhs, ref rhs) => write!(f, "BitAnd({:?}, {:?})", lhs, rhs),
Expression::BitOr(ref lhs, ref rhs) => write!(f, "BitOr({:?}, {:?})", lhs, rhs),
Expression::LeftShift(ref lhs, ref rhs) => write!(f, "LeftShift({:?}, {:?})", lhs, rhs),
Expression::RightShift(ref lhs, ref rhs) => {
write!(f, "RightShift({:?}, {:?})", lhs, rhs)
}
}
}
}
/// A list of expressions, used in return statements
#[derive(Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ExpressionList<'ast> {
pub expressions: Vec<ExpressionNode<'ast>>,
}
@ -709,9 +606,3 @@ impl<'ast> fmt::Display for ExpressionList<'ast> {
write!(f, "")
}
}
impl<'ast> fmt::Debug for ExpressionList<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ExpressionList({:?})", self.expressions)
}
}

View file

@ -74,7 +74,6 @@ impl<V: NodeValue> From<V> for Node<V> {
use crate::absy::types::UnresolvedType;
use crate::absy::*;
use crate::imports::*;
impl<'ast> NodeValue for Expression<'ast> {}
impl<'ast> NodeValue for ExpressionList<'ast> {}
@ -84,12 +83,13 @@ impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {}
impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for ConstantDefinition<'ast> {}
impl<'ast> NodeValue for Function<'ast> {}
impl<'ast> NodeValue for Module<'ast> {}
impl<'ast> NodeValue for CanonicalImport<'ast> {}
impl<'ast> NodeValue for SymbolImport<'ast> {}
impl<'ast> NodeValue for Variable<'ast> {}
impl<'ast> NodeValue for Parameter<'ast> {}
impl<'ast> NodeValue for Import<'ast> {}
impl<'ast> NodeValue for Spread<'ast> {}
impl<'ast> NodeValue for Range<'ast> {}
impl<'ast> NodeValue for Identifier<'ast> {}

View file

@ -140,19 +140,43 @@ impl From<static_analysis::Error> for CompileErrorInner {
impl fmt::Display for CompileErrorInner {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
CompileErrorInner::ParserError(ref e) => write!(f, "{}", e),
CompileErrorInner::MacroError(ref e) => write!(f, "{}", e),
CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e),
CompileErrorInner::ReadError(ref e) => write!(f, "{}", e),
CompileErrorInner::ImportError(ref e) => write!(f, "{}", e),
CompileErrorInner::AnalysisError(ref e) => write!(f, "{}", e),
CompileErrorInner::ParserError(ref e) => write!(f, "\n\t{}", e),
CompileErrorInner::MacroError(ref e) => write!(f, "\n\t{}", e),
CompileErrorInner::SemanticError(ref e) => {
let location = e
.pos()
.map(|p| format!("{}", p.0))
.unwrap_or_else(|| "".to_string());
write!(f, "{}\n\t{}", location, e.message())
}
CompileErrorInner::ReadError(ref e) => write!(f, "\n\t{}", e),
CompileErrorInner::ImportError(ref e) => {
let location = e
.pos()
.map(|p| format!("{}", p.0))
.unwrap_or_else(|| "".to_string());
write!(f, "{}\n\t{}", location, e.message())
}
CompileErrorInner::AnalysisError(ref e) => write!(f, "\n\t{}", e),
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct CompileConfig {
pub allow_unconstrained_variables: bool,
pub isolate_branches: bool,
}
impl CompileConfig {
pub fn allow_unconstrained_variables(mut self, flag: bool) -> Self {
self.allow_unconstrained_variables = flag;
self
}
pub fn isolate_branches(mut self, flag: bool) -> Self {
self.isolate_branches = flag;
self
}
}
type FilePath = PathBuf;
@ -165,7 +189,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
) -> Result<CompilationArtifacts<T>, CompileErrors> {
let arena = Arena::new();
let (typed_ast, abi) = check_with_arena(source, location, resolver, &arena)?;
let (typed_ast, abi) = check_with_arena(source, location, resolver, config, &arena)?;
// flatten input program
let program_flattened = Flattener::flatten(typed_ast, config);
@ -192,16 +216,18 @@ pub fn check<T: Field, E: Into<imports::Error>>(
source: String,
location: FilePath,
resolver: Option<&dyn Resolver<E>>,
config: &CompileConfig,
) -> Result<(), CompileErrors> {
let arena = Arena::new();
check_with_arena::<T, _>(source, location, resolver, &arena).map(|_| ())
check_with_arena::<T, _>(source, location, resolver, config, &arena).map(|_| ())
}
fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
source: String,
location: FilePath,
resolver: Option<&dyn Resolver<E>>,
config: &CompileConfig,
arena: &'ast Arena<String>,
) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> {
let source = arena.alloc(source);
@ -215,7 +241,7 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
// analyse (unroll and constant propagation)
typed_ast
.analyse()
.analyse(config)
.map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)]))
}
@ -283,7 +309,7 @@ mod test {
assert!(res.unwrap_err().0[0]
.value()
.to_string()
.contains(&"Can't resolve import without a resolver"));
.contains(&"Cannot resolve import without a resolver"));
}
#[test]

View file

@ -4,7 +4,8 @@ use crate::flat_absy::{
};
use crate::solvers::Solver;
use crate::typed_absy::types::{
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, GenericIdentifier,
ConcreteGenericsAssignment, DeclarationConstant, DeclarationSignature, DeclarationType,
GenericIdentifier,
};
use std::collections::HashMap;
use zokrates_field::{Bn128Field, Field};
@ -43,10 +44,12 @@ impl FlatEmbed {
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
FlatEmbed::Unpack => DeclarationSignature::new()
.generics(vec![Some(Constant::Generic(GenericIdentifier {
name: "N",
index: 0,
}))])
.generics(vec![Some(DeclarationConstant::Generic(
GenericIdentifier {
name: "N",
index: 0,
},
))])
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
@ -122,13 +125,12 @@ impl FlatEmbed {
.generics
.into_iter()
.map(|c| match c.unwrap() {
Constant::Generic(g) => g,
DeclarationConstant::Generic(g) => g,
_ => unreachable!(),
});
assert_eq!(gen.len(), assignment.0.len());
gen.map(|g| *assignment.0.get(&g).clone().unwrap() as u32)
.collect()
gen.map(|g| *assignment.0.get(&g).unwrap() as u32).collect()
}
pub fn id(&self) -> &'static str {

View file

@ -70,7 +70,11 @@ impl fmt::Display for FlatVariable {
impl fmt::Debug for FlatVariable {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "FlatVariable(id: {})", self.id)
match self.id {
0 => write!(f, "~one"),
i if i > 0 => write!(f, "_{}", i - 1),
i => write!(f, "~out_{}", -(i + 1)),
}
}
}

View file

@ -208,6 +208,12 @@ pub enum FlatExpression<T> {
Mult(Box<FlatExpression<T>>, Box<FlatExpression<T>>),
}
impl<T> From<T> for FlatExpression<T> {
fn from(other: T) -> Self {
Self::Number(other)
}
}
impl<T: Field> FlatExpression<T> {
pub fn apply_substitution(
self,

File diff suppressed because it is too large Load diff

View file

@ -14,6 +14,7 @@ use std::fmt;
use std::io;
use std::path::{Path, PathBuf};
use crate::absy::types::UnresolvedType;
use typed_arena::Arena;
use zokrates_common::Resolver;
use zokrates_field::{Bn128Field, Field};
@ -32,6 +33,14 @@ impl Error {
}
}
pub fn pos(&self) -> &Option<(Position, Position)> {
&self.pos
}
pub fn message(&self) -> &str {
&self.message
}
fn with_pos(self, pos: Option<(Position, Position)>) -> Error {
Error { pos, ..self }
}
@ -56,94 +65,6 @@ impl From<io::Error> for Error {
}
}
#[derive(PartialEq, Clone)]
pub enum ImportDirective<'ast> {
Main(ImportNode<'ast>),
From(Vec<ImportNode<'ast>>),
}
impl<'ast> IntoIterator for ImportDirective<'ast> {
type Item = ImportNode<'ast>;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
let vec = match self {
ImportDirective::Main(v) => vec![v],
ImportDirective::From(v) => v,
};
vec.into_iter()
}
}
type ImportPath<'ast> = &'ast Path;
#[derive(PartialEq, Clone)]
pub struct Import<'ast> {
source: ImportPath<'ast>,
symbol: Option<Identifier<'ast>>,
alias: Option<Identifier<'ast>>,
}
pub type ImportNode<'ast> = Node<Import<'ast>>;
impl<'ast> Import<'ast> {
pub fn new(symbol: Option<Identifier<'ast>>, source: ImportPath<'ast>) -> Import<'ast> {
Import {
symbol,
source,
alias: None,
}
}
pub fn get_alias(&self) -> &Option<Identifier<'ast>> {
&self.alias
}
pub fn new_with_alias(
symbol: Option<Identifier<'ast>>,
source: ImportPath<'ast>,
alias: Identifier<'ast>,
) -> Import<'ast> {
Import {
symbol,
source,
alias: Some(alias),
}
}
pub fn alias(mut self, alias: Option<Identifier<'ast>>) -> Self {
self.alias = alias;
self
}
pub fn get_source(&self) -> &ImportPath<'ast> {
&self.source
}
}
impl<'ast> fmt::Display for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(f, "import {} as {}", self.source.display(), alias),
None => write!(f, "import {}", self.source.display()),
}
}
}
impl<'ast> fmt::Debug for Import<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.alias {
Some(ref alias) => write!(
f,
"import(source: {}, alias: {})",
self.source.display(),
alias
),
None => write!(f, "import(source: {})", self.source.display()),
}
}
}
pub struct Importer;
impl Importer {
@ -154,255 +75,168 @@ impl Importer {
modules: &mut HashMap<OwnedModuleId, Module<'ast>>,
arena: &'ast Arena<String>,
) -> Result<Module<'ast>, CompileErrors> {
let mut symbols: Vec<_> = vec![];
let symbols: Vec<_> = destination
.symbols
.into_iter()
.map(|s| match s.value.symbol {
Symbol::Here(SymbolDefinition::Import(import)) => {
Importer::resolve::<T, E>(import, &location, resolver, modules, arena)
}
_ => Ok(s),
})
.collect::<Result<_, _>>()?;
for import in destination.imports {
let pos = import.pos();
let import = import.value;
let alias = import.alias;
// handle the case of special bellman and packing imports
if import.source.starts_with("EMBED") {
match import.source.to_str().unwrap() {
#[cfg(feature = "bellman")]
"EMBED/sha256round" => {
if T::id() != Bn128Field::id() {
return Err(CompileErrorInner::ImportError(
Error::new(format!(
"Embed sha256round cannot be used with curve {}",
T::name()
))
.with_pos(Some(pos)),
)
.in_file(&location)
.into());
} else {
let alias = alias.unwrap_or("sha256round");
Ok(Module::with_symbols(symbols))
}
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
.start_end(pos.0, pos.1),
);
}
}
"EMBED/unpack" => {
let alias = alias.unwrap_or("unpack");
fn resolve<'ast, T: Field, E: Into<Error>>(
import: CanonicalImportNode<'ast>,
location: &Path,
resolver: Option<&dyn Resolver<E>>,
modules: &mut HashMap<OwnedModuleId, Module<'ast>>,
arena: &'ast Arena<String>,
) -> Result<SymbolDeclarationNode<'ast>, CompileErrors> {
let pos = import.pos();
let module_id = import.value.source;
let symbol = import.value.id;
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Unpack),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u64_to_bits" => {
let alias = alias.unwrap_or("u64_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u32_to_bits" => {
let alias = alias.unwrap_or("u32_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U32ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u16_to_bits" => {
let alias = alias.unwrap_or("u16_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U16ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u8_to_bits" => {
let alias = alias.unwrap_or("u8_to_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U8ToBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u64_from_bits" => {
let alias = alias.unwrap_or("u64_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U64FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u32_from_bits" => {
let alias = alias.unwrap_or("u32_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U32FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u16_from_bits" => {
let alias = alias.unwrap_or("u16_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U16FromBits),
}
.start_end(pos.0, pos.1),
);
}
"EMBED/u8_from_bits" => {
let alias = alias.unwrap_or("u8_from_bits");
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::U8FromBits),
}
.start_end(pos.0, pos.1),
);
}
s => {
let symbol_declaration = match module_id.to_str().unwrap() {
"EMBED" => match symbol.id {
#[cfg(feature = "bellman")]
"sha256round" => {
if T::id() != Bn128Field::id() {
return Err(CompileErrorInner::ImportError(
Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)),
Error::new(format!(
"Embed sha256round cannot be used with curve {}",
T::name()
))
.with_pos(Some(pos)),
)
.in_file(&location)
.in_file(location)
.into());
} else {
SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Sha256Round),
}
}
}
} else {
// to resolve imports, we need a resolver
match resolver {
Some(res) => match res.resolve(location.clone(), import.source.to_path_buf()) {
Ok((source, new_location)) => {
// generate an alias from the imported path if none was given explicitely
let alias = import.alias.unwrap_or(
std::path::Path::new(import.source)
.file_stem()
.ok_or_else(|| {
CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
import.source.display()
)))
.in_file(&location),
)
})?
.to_str()
.unwrap(),
);
match modules.get(&new_location) {
Some(_) => {}
None => {
let source = arena.alloc(source);
let compiled = compile_module::<T, E>(
source,
new_location.clone(),
resolver,
modules,
&arena,
)?;
assert!(modules
.insert(new_location.clone(), compiled)
.is_none());
}
};
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::There(
SymbolImport::with_id_in_module(
import.symbol.unwrap_or("main"),
new_location.display().to_string(),
)
.start_end(pos.0, pos.1),
),
}
.start_end(pos.0, pos.1),
);
"unpack" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::Unpack),
},
"u64_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64ToBits),
},
"u32_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U32ToBits),
},
"u16_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U16ToBits),
},
"u8_to_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U8ToBits),
},
"u64_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U64FromBits),
},
"u32_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U32FromBits),
},
"u16_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U16FromBits),
},
"u8_from_bits" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::U8FromBits),
},
"FIELD_SIZE_IN_BITS" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Here(SymbolDefinition::Constant(
ConstantDefinition {
ty: UnresolvedType::Uint(32).into(),
expression: Expression::U32Constant(T::get_required_bits() as u32)
.into(),
}
Err(err) => {
return Err(CompileErrorInner::ImportError(
err.into().with_pos(Some(pos)),
)
.in_file(&location)
.into());
}
},
None => {
return Err(CompileErrorInner::from(Error::new(
"Can't resolve import without a resolver",
))
.in_file(&location)
.into());
}
.start_end(pos.0, pos.1),
)),
},
s => {
return Err(CompileErrorInner::ImportError(
Error::new(format!("Embed {} not found", s)).with_pos(Some(pos)),
)
.in_file(location)
.into());
}
}
}
},
_ => match resolver {
Some(res) => match res.resolve(location.to_path_buf(), module_id.to_path_buf()) {
Ok((source, new_location)) => {
let alias = symbol.alias.unwrap_or(
module_id
.file_stem()
.ok_or_else(|| {
CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
module_id.display()
)))
.in_file(location),
)
})?
.to_str()
.unwrap(),
);
symbols.extend(destination.symbols);
match modules.get(&new_location) {
Some(_) => {}
None => {
let source = arena.alloc(source);
let compiled = compile_module::<T, E>(
source,
new_location.clone(),
resolver,
modules,
&arena,
)?;
Ok(Module {
imports: vec![],
symbols,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn create_with_no_alias() {
assert_eq!(
Import::new(None, Path::new("./foo/bar/baz.zok")),
Import {
symbol: None,
source: Path::new("./foo/bar/baz.zok"),
alias: None,
}
);
}
#[test]
fn create_with_alias() {
assert_eq!(
Import::new_with_alias(None, Path::new("./foo/bar/baz.zok"), &"myalias"),
Import {
symbol: None,
source: Path::new("./foo/bar/baz.zok"),
alias: Some("myalias"),
}
);
assert!(modules.insert(new_location.clone(), compiled).is_none());
}
};
SymbolDeclaration {
id: &alias,
symbol: Symbol::There(
SymbolImport::with_id_in_module(symbol.id, new_location)
.start_end(pos.0, pos.1),
),
}
}
Err(err) => {
return Err(
CompileErrorInner::ImportError(err.into().with_pos(Some(pos)))
.in_file(location)
.into(),
);
}
},
None => {
return Err(CompileErrorInner::from(Error::new(
"Cannot resolve import without a resolver",
))
.in_file(location)
.into());
}
},
};
Ok(symbol_declaration.start_end(pos.0, pos.1))
}
}

View file

@ -259,7 +259,7 @@ impl<T: Field> LinComb<T> {
}
fn is_assignee<U>(&self, witness: &BTreeMap<FlatVariable, U>) -> bool {
self.0.iter().count() == 1
self.0.len() == 1
&& self.0.get(0).unwrap().1 == T::from(1)
&& !witness.contains_key(&self.0.get(0).unwrap().0)
}

File diff suppressed because it is too large Load diff

View file

@ -10,29 +10,6 @@ impl BoundsChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
BoundsChecker.fold_program(p)
}
pub fn check_select<'ast, T: Field, U: Select<'ast, T>>(
&mut self,
array: ArrayExpression<'ast, T>,
index: UExpression<'ast, T>,
) -> Result<U, Error> {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
match (array.get_array_type().size.as_inner(), index.as_inner()) {
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
if index >= size {
return Err(format!(
"Out of bounds access: {}[{}] but {} is of size {}",
array, index, array, size
));
}
}
_ => unreachable!(),
};
Ok(U::select(array, index))
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
@ -44,19 +21,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Select(box array, box index) => self
.check_select::<_, ArrayExpression<_>>(array, index)
.map(|a| a.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.fold_array_expression(array)?;
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
match (
array.get_array_type().size.as_inner(),
from.as_inner(),
to.as_inner(),
) {
match (array.ty().size.as_inner(), from.as_inner(), to.as_inner()) {
(
UExpressionInner::Value(size),
UExpressionInner::Value(from),
@ -86,49 +56,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
}
}
fn fold_struct_expression_inner(
fn fold_select_expression<
E: Expr<'ast, T> + Select<'ast, T> + From<TypedExpression<'ast, T>>,
>(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::Select(box array, box index) => self
.check_select::<_, StructExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_struct_expression_inner(self, ty, e),
}
}
_: &E::Ty,
select: SelectExpression<'ast, T, E>,
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
let array = self.fold_array_expression(*select.array)?;
let index = self.fold_uint_expression(*select.index)?;
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_field_expression(self, e),
}
}
match (array.ty().size.as_inner(), index.as_inner()) {
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
if index >= size {
return Err(format!(
"Out of bounds access: {}[{}] but {} is of size {}",
array, index, array, size
));
}
}
_ => unreachable!(),
};
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::Select(box array, box index) => self
.check_select::<_, UExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_uint_expression_inner(self, bitwidth, e),
}
Ok(SelectOrExpression::Select(SelectExpression::new(
array, index,
)))
}
}

View file

@ -0,0 +1,33 @@
// Isolate branches means making sure that any branch is enclosed in a block.
// This is important, because we want any statement resulting from inlining any branch to be isolated from the coller, so that its panics can be conditional to the branch being logically run
// `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks.
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct Isolator;
impl Isolator {
pub fn isolate<T: Field>(p: TypedProgram<T>) -> TypedProgram<T> {
let mut isolator = Isolator;
isolator.fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
fn fold_if_else_expression<
E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + IfElse<'ast, T>,
>(
&mut self,
_: &E::Ty,
e: IfElseExpression<'ast, T, E>,
) -> IfElseOrExpression<'ast, T, E> {
IfElseOrExpression::IfElse(IfElseExpression::new(
self.fold_boolean_expression(*e.condition),
E::block(vec![], e.consequence.fold(self)),
E::block(vec![], e.alternative.fold(self)),
))
}
}

View file

@ -0,0 +1,898 @@
use crate::static_analysis::Propagator;
use crate::typed_absy::folder::*;
use crate::typed_absy::result_folder::ResultFolder;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::*;
use std::collections::HashMap;
use std::convert::TryInto;
use zokrates_field::Field;
type ProgramConstants<'ast, T> =
HashMap<OwnedTypedModuleId, HashMap<Identifier<'ast>, TypedExpression<'ast, T>>>;
pub struct ConstantInliner<'ast, T> {
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
}
impl<'ast, 'a, T: Field> ConstantInliner<'ast, T> {
pub fn new(
modules: TypedModules<'ast, T>,
location: OwnedTypedModuleId,
constants: ProgramConstants<'ast, T>,
) -> Self {
ConstantInliner {
modules,
location,
constants,
}
}
pub fn inline(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
let constants = ProgramConstants::new();
let mut inliner = ConstantInliner::new(p.modules.clone(), p.main.clone(), constants);
inliner.fold_program(p)
}
fn change_location(&mut self, location: OwnedTypedModuleId) -> OwnedTypedModuleId {
let prev = self.location.clone();
self.location = location;
self.constants.entry(self.location.clone()).or_default();
prev
}
fn treated(&self, id: &TypedModuleId) -> bool {
self.constants.contains_key(id)
}
fn get_constant(
&self,
id: &CanonicalConstantIdentifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&id.module)
.and_then(|constants| constants.get(&id.id.into()))
.cloned()
}
fn get_constant_for_identifier(
&self,
id: &Identifier<'ast>,
) -> Option<TypedExpression<'ast, T>> {
self.constants
.get(&self.location)
.and_then(|constants| constants.get(&id))
.cloned()
}
}
impl<'ast, T: Field> Folder<'ast, T> for ConstantInliner<'ast, T> {
fn fold_module_id(&mut self, id: OwnedTypedModuleId) -> OwnedTypedModuleId {
// anytime we encounter a module id, visit the corresponding module if it hasn't been done yet
if !self.treated(&id) {
let current_m_id = self.change_location(id.clone());
let m = self.modules.remove(&id).unwrap();
let m = self.fold_module(m);
self.modules.insert(id.clone(), m);
self.change_location(current_m_id);
}
id
}
fn fold_module(&mut self, m: TypedModule<'ast, T>) -> TypedModule<'ast, T> {
TypedModule {
constants: m
.constants
.into_iter()
.map(|(id, tc)| {
let constant = match tc {
TypedConstantSymbol::There(imported_id) => {
// visit the imported symbol. This triggers visiting the corresponding module if needed
let imported_id = self.fold_canonical_constant_identifier(imported_id);
// after that, the constant must have been defined defined in the global map. It is already reduced
// to a literal, so running propagation isn't required
self.get_constant(&imported_id).unwrap()
}
TypedConstantSymbol::Here(c) => {
let non_propagated_constant = fold_constant(self, c).expression;
// folding the constant above only reduces it to an expression containing only literals, not to a single literal.
// propagating with an empty map of constants reduces it to a single literal
Propagator::with_constants(&mut HashMap::default())
.fold_expression(non_propagated_constant)
.unwrap()
}
};
// add to the constant map. The value added is always a single litteral
self.constants
.get_mut(&self.location)
.unwrap()
.insert(id.id.into(), constant.clone());
(
id,
TypedConstantSymbol::Here(TypedConstant {
ty: constant.get_type().clone(),
expression: constant,
}),
)
})
.collect(),
functions: m
.functions
.into_iter()
.map(|(key, fun)| {
(
self.fold_declaration_function_key(key),
self.fold_function_symbol(fun),
)
})
.collect(),
}
}
fn fold_declaration_constant(
&mut self,
c: DeclarationConstant<'ast>,
) -> DeclarationConstant<'ast> {
match c {
// replace constants by their concrete value in declaration types
DeclarationConstant::Constant(id) => {
DeclarationConstant::Concrete(match self.get_constant(&id).unwrap() {
TypedExpression::Uint(UExpression {
inner: UExpressionInner::Value(v),
..
}) => v as u32,
_ => unreachable!("all constants found in declaration types should be reduceable to u32 literals"),
})
}
c => c,
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_field_expression(self, e),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => c.try_into().unwrap(),
None => fold_boolean_expression(self, e),
},
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
size: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id) {
Some(c) => {
let e: UExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_uint_expression_inner(self, size, e),
},
e => fold_uint_expression_inner(self, size, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Identifier(ref id) => {
match self.get_constant_for_identifier(id) {
Some(c) => {
let e: ArrayExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_array_expression_inner(self, ty, e),
}
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Identifier(ref id) => match self.get_constant_for_identifier(id)
{
Some(c) => {
let e: StructExpression<'ast, T> = c.try_into().unwrap();
e.into_inner()
}
None => fold_struct_expression_inner(self, ty, e),
},
e => fold_struct_expression_inner(self, ty, e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
DeclarationFunctionKey, DeclarationType, FieldElementExpression, GType, Identifier,
TypedConstant, TypedExpression, TypedFunction, TypedFunctionSymbol, TypedStatement,
};
use zokrates_field::Bn128Field;
#[test]
fn inline_const_field() {
// const field a = 1
//
// def main() -> field:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(Bn128Field::from(1))),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(1)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_boolean() {
// const bool a = true
//
// def main() -> bool:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![BooleanExpression::Identifier(
Identifier::from(const_id),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Boolean,
TypedExpression::Boolean(BooleanExpression::Value(true)),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
BooleanExpression::Value(true).into()
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Boolean]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_uint() {
// const u32 a = 0x00000001
//
// def main() -> u32:
// return a
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Identifier(
Identifier::from(const_id),
)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::Uint(UBitwidth::B32),
UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into(),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![UExpressionInner::Value(1u128)
.annotate(UBitwidth::B32)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::Uint(UBitwidth::B32)]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_const_field_array() {
// const field[2] a = [2, 2]
//
// def main() -> field:
// return a[0] + a[1]
let const_id = "a";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Identifier(Identifier::from(const_id))
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let constants: TypedConstantSymbols<_> = vec![(
CanonicalConstantIdentifier::new(const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::Array(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
),
)),
)]
.into_iter()
.collect();
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: constants.clone(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![FieldElementExpression::Add(
FieldElementExpression::select(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(0u128).annotate(UBitwidth::B32),
)
.into(),
FieldElementExpression::select(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
]
.into(),
)
.annotate(GType::FieldElement, 2usize),
UExpressionInner::Value(1u128).annotate(UBitwidth::B32),
)
.into(),
)
.into()])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants,
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_nested_const_field() {
// const field a = 1
// const field b = a + 1
//
// def main() -> field:
// return b
let const_a_id = "a";
let const_b_id = "b";
let main: TypedFunction<Bn128Field> = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(const_b_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(main),
)]
.into_iter()
.collect(),
constants: vec![
(
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from(
const_a_id,
)),
box FieldElementExpression::Number(Bn128Field::from(1)),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main = TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(2)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(expected_main),
)]
.into_iter()
.collect(),
constants: vec![
(
CanonicalConstantIdentifier::new(const_a_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(1),
)),
)),
),
(
CanonicalConstantIdentifier::new(const_b_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(2),
)),
)),
),
]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
#[test]
fn inline_imported_constant() {
// ---------------------
// module `foo`
// --------------------
// const field FOO = 42
//
// def main():
// return
//
// ---------------------
// module `main`
// ---------------------
// from "foo" import FOO
//
// def main() -> field:
// return FOO
let foo_const_id = "FOO";
let foo_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main")
.signature(DeclarationSignature::new().inputs(vec![]).outputs(vec![])),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![],
signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "foo".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
)]
.into_iter()
.collect(),
};
let main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Identifier(Identifier::from(foo_const_id)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::There(CanonicalConstantIdentifier::new(
foo_const_id,
"foo".into(),
)),
)]
.into_iter()
.collect(),
};
let program = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), main_module),
("foo".into(), foo_module.clone()),
]
.into_iter()
.collect(),
};
let program = ConstantInliner::inline(program);
let expected_main_module = TypedModule {
functions: vec![(
DeclarationFunctionKey::with_location("main", "main").signature(
DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
statements: vec![TypedStatement::Return(vec![
FieldElementExpression::Number(Bn128Field::from(42)).into(),
])],
signature: DeclarationSignature::new()
.inputs(vec![])
.outputs(vec![DeclarationType::FieldElement]),
}),
)]
.into_iter()
.collect(),
constants: vec![(
CanonicalConstantIdentifier::new(foo_const_id, "main".into()),
TypedConstantSymbol::Here(TypedConstant::new(
GType::FieldElement,
TypedExpression::FieldElement(FieldElementExpression::Number(
Bn128Field::from(42),
)),
)),
)]
.into_iter()
.collect(),
};
let expected_program: TypedProgram<Bn128Field> = TypedProgram {
main: "main".into(),
modules: vec![
("main".into(), expected_main_module),
("foo".into(), foo_module),
]
.into_iter()
.collect(),
};
assert_eq!(program, expected_program)
}
}

File diff suppressed because it is too large Load diff

View file

@ -5,10 +5,11 @@
//! @date 2018
mod bounds_checker;
mod branch_isolator;
mod constant_inliner;
mod flat_propagation;
mod flatten_complex_types;
mod propagation;
mod redefinition;
mod reducer;
mod shift_checker;
mod uint_optimizer;
@ -17,17 +18,19 @@ mod variable_read_remover;
mod variable_write_remover;
use self::bounds_checker::BoundsChecker;
use self::branch_isolator::Isolator;
use self::flatten_complex_types::Flattener;
use self::propagation::Propagator;
use self::redefinition::RedefinitionOptimizer;
use self::reducer::reduce_program;
use self::shift_checker::ShiftChecker;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_read_remover::VariableReadRemover;
use self::variable_write_remover::VariableWriteRemover;
use crate::compile::CompileConfig;
use crate::flat_absy::FlatProg;
use crate::ir::Prog;
use crate::static_analysis::constant_inliner::ConstantInliner;
use crate::typed_absy::{abi::Abi, TypedProgram};
use crate::zir::ZirProgram;
use std::fmt;
@ -72,15 +75,23 @@ impl fmt::Display for Error {
}
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
let r = reduce_program(self).map_err(Error::from)?;
pub fn analyse(self, config: &CompileConfig) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// inline user-defined constants
let r = ConstantInliner::inline(self);
// isolate branches
let r = if config.isolate_branches {
Isolator::isolate(r)
} else {
r
};
// reduce the program to a single function
let r = reduce_program(r).map_err(Error::from)?;
// generate abi
let abi = r.abi();
// propagate
let r = Propagator::propagate(r).map_err(Error::from)?;
// optimize redefinitions
let r = RedefinitionOptimizer::optimize(r);
// remove assignment to variable index
let r = VariableWriteRemover::apply(r);
// remove variable access to complex types

View file

@ -152,13 +152,16 @@ fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
}
}
fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
fn remove_spreads_aux<T: Field>(e: TypedExpressionOrSpread<T>) -> Vec<TypedExpression<T>> {
// in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
fn to_canonical_constant<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
fn to_canonical_constant_aux<T: Field>(
e: TypedExpressionOrSpread<T>,
) -> Vec<TypedExpression<T>> {
match e {
TypedExpressionOrSpread::Expression(e) => vec![e],
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
ArrayExpressionInner::Value(v) => {
v.into_iter().flat_map(remove_spreads_aux).collect()
v.into_iter().flat_map(to_canonical_constant_aux).collect()
}
_ => unimplemented!(),
},
@ -167,12 +170,12 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
match e {
TypedExpression::Array(a) => {
let array_ty = a.get_array_type();
let array_ty = a.ty();
match a.into_inner() {
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
v.into_iter()
.flat_map(remove_spreads_aux)
.flat_map(to_canonical_constant_aux)
.map(|e| e.into())
.collect::<Vec<_>>()
.into(),
@ -197,7 +200,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
ArrayExpressionInner::Value(
v.into_iter()
.flat_map(remove_spreads_aux)
.flat_map(to_canonical_constant_aux)
.map(|e| e.into())
.enumerate()
.filter(|(index, _)| index >= &from && index < &to)
@ -214,7 +217,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
_ => unreachable!("should be a uint value"),
};
let e = remove_spreads(e);
let e = to_canonical_constant(e);
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(e); count].into(),
@ -225,6 +228,18 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
_ => unreachable!(),
}
}
TypedExpression::Struct(s) => {
let struct_ty = s.ty().clone();
match s.into_inner() {
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
expressions.into_iter().map(to_canonical_constant).collect(),
)
.annotate(struct_ty)
.into(),
_ => unreachable!(),
}
}
e => e,
}
}
@ -264,6 +279,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
})
.collect::<Result<_, _>>()?,
..m
})
}
@ -274,6 +290,35 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
fold_function(self, f)
}
fn fold_if_else_expression<
E: Expr<'ast, T> + IfElse<'ast, T> + PartialEq + ResultFold<'ast, T>,
>(
&mut self,
_: &E::Ty,
e: IfElseExpression<'ast, T, E>,
) -> Result<IfElseOrExpression<'ast, T, E>, Self::Error> {
Ok(
match (
self.fold_boolean_expression(*e.condition)?,
e.consequence.fold(self)?,
e.alternative.fold(self)?,
) {
(BooleanExpression::Value(true), consequence, _) => {
IfElseOrExpression::Expression(consequence.into_inner())
}
(BooleanExpression::Value(false), _, alternative) => {
IfElseOrExpression::Expression(alternative.into_inner())
}
(_, consequence, alternative) if consequence == alternative => {
IfElseOrExpression::Expression(consequence.into_inner())
}
(condition, consequence, alternative) => IfElseOrExpression::IfElse(
IfElseExpression::new(condition, consequence, alternative),
),
},
)
}
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
@ -299,7 +344,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
if is_constant(&expr) {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = remove_spreads(expr);
let expr = to_canonical_constant(expr);
assert!(self.constants.insert(var.id, expr).is_none());
@ -307,7 +352,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = remove_spreads(expr);
*c = to_canonical_constant(expr);
Ok(vec![])
}
Err(v) => match self.constants.remove(&v.id) {
@ -351,18 +396,23 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
.collect::<Result<_, _>>()?;
let expression_list = self.fold_expression_list(expression_list)?;
let statements = match expression_list {
TypedExpressionList::EmbedCall(embed, generics, arguments, types) => {
let types = Types {
inner: expression_list
.types
.clone()
.inner
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?,
};
let statements = match expression_list.into_inner() {
TypedExpressionListInner::EmbedCall(embed, generics, arguments) => {
let arguments: Vec<_> = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
let types = types
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?;
fn process_u_from_bits<'ast, T: Field>(
variables: Vec<TypedAssignee<'ast, T>>,
mut arguments: Vec<TypedExpression<'ast, T>>,
@ -373,7 +423,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let argument = arguments.pop().unwrap();
let argument = remove_spreads(argument);
let argument = to_canonical_constant(argument);
match ArrayExpression::try_from(argument)
.unwrap()
@ -586,16 +636,18 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
TypedStatement::Definition(v.clone().into(), c),
TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(
embed, generics, arguments,
)
.annotate(types),
),
],
None => vec![TypedStatement::MultipleDefinition(
vec![assignee],
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(
embed, generics, arguments,
)
.annotate(types),
)],
}
}
@ -607,9 +659,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let def = TypedStatement::MultipleDefinition(
assignees.clone(),
TypedExpressionList::EmbedCall(
embed, generics, arguments, types,
),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(types),
);
let invalidations = assignees.iter().flat_map(|assignee| {
@ -629,27 +680,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
}
TypedExpressionList::FunctionCall(key, generics, arguments, types) => {
let generics = generics
TypedExpressionListInner::FunctionCall(function_call) => {
let generics = function_call
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments: Vec<_> = arguments
let arguments: Vec<_> = function_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
let types = types
.into_iter()
.map(|t| self.fold_type(t))
.collect::<Result<_, _>>()?;
// invalidate the cache for the return assignees as this call mutates them
let def = TypedStatement::MultipleDefinition(
assignees.clone(),
TypedExpressionList::FunctionCall(key, generics, arguments, types),
TypedExpressionList::function_call(
function_call.function_key,
generics,
arguments,
)
.annotate(types),
);
let invalidations = assignees.iter().flat_map(|assignee| {
@ -889,19 +942,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
box e2.annotate(bitwidth),
)),
},
UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_uint_expression(consequence)?;
let alternative = self.fold_uint_expression(alternative)?;
match self.fold_boolean_expression(condition)? {
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
c => Ok(UExpressionInner::IfElse(
box c,
box consequence,
box alternative,
)),
}
}
UExpressionInner::Not(box e) => {
let e = self.fold_uint_expression(e)?.into_inner();
match e {
@ -928,67 +968,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e => Ok(UExpressionInner::Pos(box e.annotate(bitwidth))),
}
}
UExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => {
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(UExpression::try_from(
v.expression_at::<UExpression<'ast, T>>(n as usize)
.unwrap()
.clone(),
)
.unwrap()
.into_inner())
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(UExpression::try_from(
TypedExpression::try_from(
v.0[n as usize].clone(),
)
.unwrap(),
)
.unwrap()
.into_inner())
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(UExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
)),
}
}
(a, i) => Ok(UExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
)),
}
}
_ => fold_uint_expression_inner(
self,
bitwidth,
UExpressionInner::Select(box array, box index),
),
}
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
@ -1072,110 +1051,112 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
)),
}
}
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_field_expression(consequence)?;
let alternative = self.fold_field_expression(alternative)?;
match self.fold_boolean_expression(condition)? {
BooleanExpression::Value(true) => Ok(consequence),
BooleanExpression::Value(false) => Ok(alternative),
c => Ok(FieldElementExpression::IfElse(
box c,
box consequence,
box alternative,
)),
}
}
FieldElementExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => {
match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(FieldElementExpression::try_from(
v.expression_at::<FieldElementExpression<'ast, T>>(
n as usize,
)
.unwrap()
.clone(),
)
.unwrap())
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(FieldElementExpression::try_from(
TypedExpression::try_from(
v.0[n as usize].clone(),
)
.unwrap(),
)
.unwrap())
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
)),
}
}
(a, i) => Ok(FieldElementExpression::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
)),
}
}
_ => fold_field_expression(
self,
FieldElementExpression::Select(box array, box index),
),
}
}
FieldElementExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s)?;
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct type"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::FieldElement(s) => Ok(s),
_ => unreachable!("should be a field element expression"),
}
}
inner => Ok(FieldElementExpression::Member(
box inner.annotate(members),
m,
)),
}
}
e => fold_field_expression(self, e),
}
}
fn fold_member_expression<
E: Expr<'ast, T> + Member<'ast, T> + From<TypedExpression<'ast, T>>,
>(
&mut self,
_: &E::Ty,
m: MemberExpression<'ast, T, E>,
) -> Result<MemberOrExpression<'ast, T, E>, Self::Error> {
let id = m.id;
let struc = self.fold_struct_expression(*m.struc)?;
let ty = struc.ty().clone();
match struc.into_inner() {
StructExpressionInner::Value(v) => Ok(MemberOrExpression::Expression(
E::from(
ty.members
.iter()
.zip(v)
.find(|(member, _)| member.id == id)
.unwrap()
.1,
)
.into_inner(),
)),
inner => Ok(MemberOrExpression::Member(MemberExpression::new(
inner.annotate(ty),
id,
))),
}
}
fn fold_select_expression<
E: Expr<'ast, T> + Select<'ast, T> + From<TypedExpression<'ast, T>>,
>(
&mut self,
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
let array = self.fold_array_expression(*e.array)?;
let index = self.fold_uint_expression(*e.index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) {
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(SelectOrExpression::Expression(
E::from(
v.expression_at::<StructExpression<'ast, T>>(n as usize)
.unwrap()
.clone(),
)
.into_inner(),
))
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(SelectOrExpression::Expression(
E::from(
v.expression_at::<StructExpression<'ast, T>>(
n as usize,
)
.unwrap()
.clone(),
)
.into_inner(),
))
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(SelectOrExpression::Expression(
E::select(
ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
UExpressionInner::Value(n).annotate(UBitwidth::B32),
)
.into_inner(),
)),
}
}
(a, i) => Ok(SelectOrExpression::Select(SelectExpression::new(
a.annotate(inner_type, size as u32),
i.annotate(UBitwidth::B32),
))),
},
_ => Ok(SelectOrExpression::Select(SelectExpression::new(
array, index,
))),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
@ -1189,107 +1170,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
},
None => Ok(ArrayExpressionInner::Identifier(id)),
},
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(ArrayExpression::try_from(
v.expression_at::<ArrayExpression<'ast, T>>(n as usize)
.unwrap()
.clone(),
)
.unwrap()
.into_inner())
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(ArrayExpression::try_from(
v.expression_at::<ArrayExpression<'ast, T>>(
n as usize,
)
.unwrap()
.clone(),
)
.unwrap()
.into_inner())
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
)),
}
}
(a, i) => Ok(ArrayExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
)),
},
_ => fold_array_expression_inner(
self,
ty,
ArrayExpressionInner::Select(box array, box index),
),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence)?;
let alternative = self.fold_array_expression(alternative)?;
match self.fold_boolean_expression(condition)? {
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
c => Ok(ArrayExpressionInner::IfElse(
box c,
box consequence,
box alternative,
)),
}
}
ArrayExpressionInner::Member(box struc, id) => {
let struc = self.fold_struct_expression(struc)?;
let members = match struc.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct type"),
};
match struc.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == id)
.unwrap()
.1
{
TypedExpression::Array(a) => Ok(a.into_inner()),
_ => unreachable!("should be an array expression"),
}
}
inner => Ok(ArrayExpressionInner::Member(
box inner.annotate(members),
id,
)),
}
}
e => fold_array_expression_inner(self, ty, e),
}
}
@ -1307,106 +1187,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
},
None => Ok(StructExpressionInner::Identifier(id)),
},
StructExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(StructExpression::try_from(
v.expression_at::<StructExpression<'ast, T>>(n as usize)
.unwrap()
.clone(),
)
.unwrap()
.into_inner())
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(StructExpression::try_from(
v.expression_at::<StructExpression<'ast, T>>(
n as usize,
)
.unwrap()
.clone(),
)
.unwrap()
.into_inner())
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(StructExpressionInner::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
)),
}
}
(a, i) => Ok(StructExpressionInner::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
)),
},
_ => fold_struct_expression_inner(
self,
ty,
StructExpressionInner::Select(box array, box index),
),
}
}
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_struct_expression(consequence)?;
let alternative = self.fold_struct_expression(alternative)?;
match self.fold_boolean_expression(condition)? {
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
c => Ok(StructExpressionInner::IfElse(
box c,
box consequence,
box alternative,
)),
}
}
StructExpressionInner::Member(box s, m) => {
let s = self.fold_struct_expression(s)?;
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct type"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::Struct(s) => Ok(s.into_inner()),
_ => unreachable!("should be a struct expression"),
StructExpressionInner::Value(v) => {
let v = v.into_iter().zip(ty.iter()).map(|(v, member)|
match self.fold_expression(v) {
Ok(v) => match (ConcreteType::try_from(v.get_type().clone()), ConcreteType::try_from(*member.ty.clone())) {
(Ok(t1), Ok(t2)) => if t1 == t2 { Ok(v) } else { Err(Error::Type(format!(
"Struct member `{}` in struct `{}/{}` expected to have type `{}`, found type `{}`",
member.id, ty.canonical_location.clone().module.display(), ty.canonical_location.clone().name, t2, t1
))) },
_ => Ok(v)
}
e => e
}
inner => Ok(StructExpressionInner::Member(
box inner.annotate(members),
m,
)),
}
).collect::<Result<_, _>>()?;
Ok(StructExpressionInner::Value(v))
}
e => fold_struct_expression_inner(self, ty, e),
}
@ -1615,98 +1410,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
e => Ok(BooleanExpression::Not(box e)),
}
}
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_boolean_expression(consequence)?;
let alternative = self.fold_boolean_expression(alternative)?;
match self.fold_boolean_expression(condition)? {
BooleanExpression::Value(true) => Ok(consequence),
BooleanExpression::Value(false) => Ok(alternative),
c => Ok(BooleanExpression::IfElse(
box c,
box consequence,
box alternative,
)),
}
}
BooleanExpression::Select(box array, box index) => {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
let inner_type = array.inner_type().clone();
let size = array.size();
match size.into_inner() {
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner())
{
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
if n < size {
Ok(BooleanExpression::try_from(
v.expression_at::<BooleanExpression<'ast, T>>(n as usize)
.unwrap()
.clone(),
)
.unwrap())
} else {
Err(Error::OutOfBounds(n, size))
}
}
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
match self.constants.get(&id) {
Some(a) => match a {
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => {
Ok(BooleanExpression::try_from(
TypedExpression::try_from(v.0[n as usize].clone())
.unwrap(),
)
.unwrap())
}
_ => unreachable!("should be an array value"),
},
_ => unreachable!("should be an array expression"),
},
None => Ok(BooleanExpression::Select(
box ArrayExpressionInner::Identifier(id)
.annotate(inner_type, size as u32),
box UExpressionInner::Value(n).annotate(UBitwidth::B32),
)),
}
}
(a, i) => Ok(BooleanExpression::Select(
box a.annotate(inner_type, size as u32),
box i.annotate(UBitwidth::B32),
)),
},
_ => fold_boolean_expression(
self,
BooleanExpression::Select(box array, box index),
),
}
}
BooleanExpression::Member(box s, m) => {
let s = self.fold_struct_expression(s)?;
let members = match s.get_type() {
Type::Struct(members) => members,
_ => unreachable!("should be a struct type"),
};
match s.into_inner() {
StructExpressionInner::Value(v) => {
match members
.iter()
.zip(v)
.find(|(member, _)| member.id == m)
.unwrap()
.1
{
TypedExpression::Boolean(s) => Ok(s),
_ => unreachable!("should be a boolean expression"),
}
}
inner => Ok(BooleanExpression::Member(box inner.annotate(members), m)),
}
}
e => fold_boolean_expression(self, e),
}
}
@ -1792,10 +1495,10 @@ mod tests {
#[test]
fn if_else_true() {
let e = FieldElementExpression::IfElse(
box BooleanExpression::Value(true),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
let e = FieldElementExpression::if_else(
BooleanExpression::Value(true),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(3)),
);
assert_eq!(
@ -1806,10 +1509,10 @@ mod tests {
#[test]
fn if_else_false() {
let e = FieldElementExpression::IfElse(
box BooleanExpression::Value(false),
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
let e = FieldElementExpression::if_else(
BooleanExpression::Value(false),
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(3)),
);
assert_eq!(
@ -1820,8 +1523,8 @@ mod tests {
#[test]
fn select() {
let e = FieldElementExpression::Select(
box ArrayExpressionInner::Value(
let e = FieldElementExpression::select(
ArrayExpressionInner::Value(
vec![
FieldElementExpression::Number(Bn128Field::from(1)).into(),
FieldElementExpression::Number(Bn128Field::from(2)).into(),
@ -1830,7 +1533,7 @@ mod tests {
.into(),
)
.annotate(Type::FieldElement, 3usize),
box UExpressionInner::Add(box 1u32.into(), box 1u32.into())
UExpressionInner::Add(box 1u32.into(), box 1u32.into())
.annotate(UBitwidth::B32),
);

View file

@ -1,73 +0,0 @@
use crate::typed_absy::folder::*;
use crate::typed_absy::*;
use std::collections::HashMap;
use zokrates_field::Field;
pub struct RedefinitionOptimizer<'ast> {
identifiers: HashMap<Identifier<'ast>, Identifier<'ast>>,
}
impl<'ast> RedefinitionOptimizer<'ast> {
fn new() -> Self {
RedefinitionOptimizer {
identifiers: HashMap::new(),
}
}
pub fn optimize<T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
RedefinitionOptimizer::new().fold_program(p)
}
}
fn try_id<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> Option<Identifier<'ast>> {
match e {
TypedExpression::FieldElement(FieldElementExpression::Identifier(id)) => Some(id.clone()),
TypedExpression::Boolean(BooleanExpression::Identifier(id)) => Some(id.clone()),
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Identifier(id) => Some(id.clone()),
_ => None,
},
TypedExpression::Struct(a) => match a.as_inner() {
StructExpressionInner::Identifier(id) => Some(id.clone()),
_ => None,
},
TypedExpression::Uint(a) => match a.as_inner() {
UExpressionInner::Identifier(id) => Some(id.clone()),
_ => None,
},
_ => None,
}
}
impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
self.identifiers = HashMap::new();
fold_function(self, f)
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
let expr = self.fold_expression(expr);
match try_id(&expr) {
Some(id) => {
let target = self.identifiers.get(&id).unwrap_or(&id).clone();
self.identifiers.insert(var.id, target);
vec![]
}
None => vec![TypedStatement::Definition(
TypedAssignee::Identifier(var),
expr,
)],
}
}
s => fold_statement(self, s),
}
}
fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> {
self.identifiers.get(&s).cloned().unwrap_or(s)
}
}

View file

@ -29,14 +29,14 @@ use crate::embed::FlatEmbed;
use crate::static_analysis::reducer::Output;
use crate::static_analysis::reducer::ShallowTransformer;
use crate::static_analysis::reducer::Versions;
use crate::typed_absy::types::ConcreteGenericsAssignment;
use crate::typed_absy::types::{ConcreteGenericsAssignment, IntoTypes};
use crate::typed_absy::CoreIdentifier;
use crate::typed_absy::Identifier;
use crate::typed_absy::TypedAssignee;
use crate::typed_absy::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
UExpressionInner, Variable,
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr,
Signature, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, Types,
UExpression, UExpressionInner, Variable,
};
use zokrates_field::Field;
@ -46,13 +46,13 @@ pub enum InlineError<'ast, T> {
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
Types<'ast, T>,
),
NonConstant(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
Types<'ast, T>,
),
}
@ -79,11 +79,11 @@ type InlineResult<'ast, T> = Result<
InlineError<'ast, T>,
>;
pub fn inline_call<'a, 'ast, T: Field>(
pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>(
k: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
output: &E::Ty,
program: &TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> InlineResult<'ast, T> {
@ -91,6 +91,8 @@ pub fn inline_call<'a, 'ast, T: Field>(
use crate::typed_absy::Typed;
let output_types = output.clone().into_types();
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
.iter()
@ -117,7 +119,7 @@ pub fn inline_call<'a, 'ast, T: Field>(
let inferred_signature = Signature::new()
.generics(generics.clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.outputs(output_types.clone());
.outputs(output_types.clone().inner);
// we try to get concrete values for the whole signature. if this fails we should propagate again
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {

View file

@ -22,15 +22,12 @@ use crate::typed_absy::Folder;
use std::collections::HashMap;
use crate::typed_absy::{
ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier,
DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression,
StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction,
TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner,
Variable,
ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall,
FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, TypedExpression,
TypedExpressionList, TypedExpressionListInner, TypedFunction, TypedFunctionSymbol, TypedModule,
TypedProgram, TypedStatement, UExpression, UExpressionInner, Variable,
};
use std::convert::{TryFrom, TryInto};
use zokrates_field::Field;
use self::shallow_ssa::ShallowTransformer;
@ -39,6 +36,8 @@ use crate::static_analysis::Propagator;
use std::fmt;
const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20);
// An SSA version map, giving access to the latest version number for each identifier
pub type Versions<'ast> = HashMap<CoreIdentifier<'ast>, usize>;
@ -55,6 +54,7 @@ pub enum Error {
GenericsInMain,
// TODO: give more details about what's blocking the progress
NoProgress,
LoopTooLarge(u128),
}
impl fmt::Display for Error {
@ -66,7 +66,8 @@ impl fmt::Display for Error {
s
),
Error::GenericsInMain => write!(f, "Cannot generate code for generic function"),
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds")
Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"),
Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE),
}
}
}
@ -147,7 +148,7 @@ fn register<'ast>(
) {
for (id, key, value) in substitute
.iter()
.filter_map(|(id, version)| with.get(&id).clone().map(|to| (id, version, to)))
.filter_map(|(id, version)| with.get(&id).map(|to| (id, version, to)))
.filter(|(_, key, value)| key != value)
{
let sub = substitutions.0.entry(id.clone()).or_default();
@ -191,62 +192,66 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
complete: true,
}
}
}
fn fold_function_call<E>(
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
key: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
) -> Result<E, Error>
where
E: FunctionCall<'ast, T> + TryFrom<TypedExpression<'ast, T>, Error = ()> + std::fmt::Debug,
{
let generics = generics
ty: &E::Ty,
e: FunctionCallExpression<'ast, T, E>,
) -> Result<FunctionCallOrExpression<'ast, T, E>, Self::Error> {
let generics = e
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments = arguments
let arguments = e
.arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect::<Result<_, _>>()?;
let res = inline_call(
key.clone(),
let res = inline_call::<_, E>(
e.function_key.clone(),
generics,
arguments,
output_types,
ty,
&self.program,
&mut self.versions,
);
match res {
Ok(Output::Complete((statements, expressions))) => {
Ok(Output::Complete((statements, mut expressions))) => {
self.complete &= true;
self.statement_buffer.extend(statements);
Ok(expressions[0].clone().try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(
E::from(expressions.pop().unwrap()).into_inner(),
))
}
Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => {
self.complete = false;
self.statement_buffer.extend(statements);
self.for_loop_versions_after.extend(delta_for_loop_versions);
Ok(expressions[0].clone().try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(
E::from(expressions[0].clone()).into_inner(),
))
}
Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!(
"Call site `{}` incompatible with declaration `{}`",
conc.to_string(),
decl.to_string()
))),
Err(InlineError::NonConstant(key, generics, arguments, mut output_types)) => {
Err(InlineError::NonConstant(key, generics, arguments, _)) => {
self.complete = false;
Ok(E::function_call(
key,
generics,
arguments,
output_types.pop().unwrap(),
))
Ok(FunctionCallOrExpression::Expression(E::function_call(
key, generics, arguments,
)))
}
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
let identifier = Identifier::from(CoreIdentifier::Call(0)).version(
@ -256,23 +261,48 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
);
let var = Variable::with_id_and_type(identifier, output_types[0].clone());
let var = Variable::with_id_and_type(
identifier.clone(),
output_types.clone().inner.pop().unwrap(),
);
let v = vec![var.clone().into()];
self.statement_buffer
.push(TypedStatement::MultipleDefinition(
v,
TypedExpressionList::EmbedCall(embed, generics, arguments, output_types),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(output_types),
));
Ok(TypedExpression::from(var).try_into().unwrap())
Ok(FunctionCallOrExpression::Expression(E::identifier(
identifier,
)))
}
}
}
}
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
type Error = Error;
fn fold_block_expression<E: ResultFold<'ast, T>>(
&mut self,
b: BlockExpression<'ast, T, E>,
) -> Result<BlockExpression<'ast, T, E>, Self::Error> {
// backup the statements and continue with a fresh state
let statement_buffer = std::mem::take(&mut self.statement_buffer);
let block = fold_block_expression(self, b)?;
// put the original statements back and extract the statements created by visiting the block
let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer);
// return the visited block, augmented with the statements created while visiting it
Ok(BlockExpression {
statements: block
.statements
.into_iter()
.chain(extra_statements)
.collect(),
..block
})
}
fn fold_statement(
&mut self,
@ -281,23 +311,28 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let res = match s {
TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(key, generics, arguments, output_types),
TypedExpressionList {
inner: TypedExpressionListInner::FunctionCall(function_call),
types,
},
) => {
let generics = generics
let generics = function_call
.generics
.into_iter()
.map(|g| g.map(|g| self.fold_uint_expression(g)).transpose())
.collect::<Result<_, _>>()?;
let arguments = arguments
let arguments = function_call
.arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect::<Result<_, _>>()?;
match inline_call(
key,
match inline_call::<_, TypedExpressionList<'ast, T>>(
function_call.function_key,
generics,
arguments,
output_types,
&types,
&self.program,
&mut self.versions,
) {
@ -340,23 +375,15 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::FunctionCall(
key,
generics,
arguments,
output_types,
),
TypedExpressionList::function_call(key, generics, arguments)
.annotate(output_types),
)])
}
Err(InlineError::Flat(embed, generics, arguments, output_types)) => {
Ok(vec![TypedStatement::MultipleDefinition(
v,
TypedExpressionList::EmbedCall(
embed,
generics,
arguments,
output_types,
),
TypedExpressionListInner::EmbedCall(embed, generics, arguments)
.annotate(output_types),
)])
}
}
@ -383,6 +410,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
let mut transformer = ShallowTransformer::with_versions(&mut self.versions);
if to - from > MAX_FOR_LOOP_SIZE {
return Err(Error::LoopTooLarge(to.saturating_sub(*from)));
}
for index in *from..*to {
let statements: Vec<TypedStatement<_>> =
std::iter::once(TypedStatement::Definition(
@ -429,59 +460,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
res.map(|res| self.statement_buffer.drain(..).chain(res).collect())
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::FunctionCall(key, generics, arguments) => {
self.fold_function_call(key, generics, arguments, vec![Type::Boolean])
}
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression(
&mut self,
e: UExpression<'ast, T>,
) -> Result<UExpression<'ast, T>, Self::Error> {
match e.as_inner() {
UExpressionInner::FunctionCall(key, generics, arguments) => self.fold_function_call(
key.clone(),
generics.clone(),
arguments.clone(),
vec![e.get_type()],
),
_ => fold_uint_expression(self, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::FunctionCall(key, generic, arguments) => {
self.fold_function_call(key, generic, arguments, vec![Type::FieldElement])
}
e => fold_field_expression(self, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
array_ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call::<ArrayExpression<_>>(
key.clone(),
generics,
arguments.clone(),
vec![Type::array(ty.clone())],
)
.map(|e| e.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.fold_array_expression(array)?;
let from = self.fold_uint_expression(from)?;
@ -497,23 +481,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
}
}
}
_ => fold_array_expression_inner(self, &ty, e),
}
}
fn fold_struct_expression(
&mut self,
e: StructExpression<'ast, T>,
) -> Result<StructExpression<'ast, T>, Self::Error> {
match e.as_inner() {
StructExpressionInner::FunctionCall(key, generics, arguments) => self
.fold_function_call(
key.clone(),
generics.clone(),
arguments.clone(),
vec![e.get_type()],
),
_ => fold_struct_expression(self, e),
_ => fold_array_expression_inner(self, array_ty, e),
}
}
}
@ -547,6 +515,7 @@ pub fn reduce_program<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, E
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -645,13 +614,13 @@ fn compute_hash<T: Field>(f: &TypedFunction<T>) -> u64 {
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::Constant;
use crate::typed_absy::types::DeclarationConstant;
use crate::typed_absy::types::DeclarationSignature;
use crate::typed_absy::{
ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable,
FieldElementExpression, GenericIdentifier, Identifier, OwnedTypedModuleId, Select, Type,
TypedExpression, TypedExpressionList, TypedExpressionOrSpread, UBitwidth, UExpressionInner,
Variable,
ArrayExpression, ArrayExpressionInner, DeclarationFunctionKey, DeclarationType,
DeclarationVariable, FieldElementExpression, GenericIdentifier, Identifier,
OwnedTypedModuleId, Select, Type, TypedExpression, TypedExpressionList,
TypedExpressionOrSpread, Types, UBitwidth, UExpressionInner, Variable,
};
use zokrates_field::Bn128Field;
@ -712,7 +681,7 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -720,8 +689,8 @@ mod tests {
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -761,6 +730,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -826,6 +796,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -863,11 +834,11 @@ mod tests {
)])
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))]);
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -908,15 +879,15 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -959,6 +930,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1043,6 +1015,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1080,11 +1053,11 @@ mod tests {
)])
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))]);
let foo: TypedFunction<Bn128Field> = TypedFunction {
@ -1134,15 +1107,15 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -1185,6 +1158,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1269,6 +1243,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1310,11 +1285,11 @@ mod tests {
let foo_signature = DeclarationSignature::new()
.inputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.outputs(vec![DeclarationType::array((
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
))])
.generics(vec![Some(
GenericIdentifier::with_name("K").index(0).into(),
@ -1324,7 +1299,7 @@ mod tests {
arguments: vec![DeclarationVariable::array(
"a",
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
)
.into()],
statements: vec![
@ -1336,7 +1311,7 @@ mod tests {
)
.into(),
ArrayExpressionInner::Slice(
box ArrayExpressionInner::FunctionCall(
box ArrayExpression::function_call(
DeclarationFunctionKey::with_location("main", "bar")
.signature(foo_signature.clone()),
vec![None],
@ -1388,7 +1363,7 @@ mod tests {
arguments: vec![DeclarationVariable::array(
"a",
DeclarationType::FieldElement,
Constant::Generic(GenericIdentifier::with_name("K").index(0)),
DeclarationConstant::Generic(GenericIdentifier::with_name("K").index(0)),
)
.into()],
statements: vec![TypedStatement::Return(vec![
@ -1407,7 +1382,7 @@ mod tests {
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
@ -1416,8 +1391,8 @@ mod tests {
)
.annotate(Type::FieldElement, 1u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Return(vec![]),
],
@ -1447,6 +1422,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1558,6 +1534,7 @@ mod tests {
)]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()
@ -1611,15 +1588,15 @@ mod tests {
statements: vec![
TypedStatement::MultipleDefinition(
vec![Variable::array("b", Type::FieldElement, 1u32).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo")
.signature(foo_signature.clone()),
vec![None],
vec![ArrayExpressionInner::Value(vec![].into())
.annotate(Type::FieldElement, 0u32)
.into()],
vec![Type::array((Type::FieldElement, 1u32))],
),
)
.annotate(Types::new(vec![Type::array((Type::FieldElement, 1u32))])),
),
TypedStatement::Return(vec![]),
],
@ -1646,6 +1623,7 @@ mod tests {
]
.into_iter()
.collect(),
constants: Default::default(),
},
)]
.into_iter()

View file

@ -174,88 +174,18 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
res
}
fn fold_field_expression(
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
if let FieldElementExpression::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
ty: &E::Ty,
c: FunctionCallExpression<'ast, T, E>,
) -> FunctionCallOrExpression<'ast, T, E> {
if !c.function_key.id.starts_with('_') {
self.blocked = true;
}
fold_field_expression(self, e)
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
if let BooleanExpression::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_boolean_expression(self, e)
}
fn fold_uint_expression_inner(
&mut self,
b: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
if let UExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_uint_expression_inner(self, b, e)
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
if let ArrayExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_array_expression_inner(self, ty, e)
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
if let StructExpressionInner::FunctionCall(ref k, _, _) = e {
if !k.id.starts_with('_') {
self.blocked = true;
}
};
fold_struct_expression_inner(self, ty, e)
}
fn fold_expression_list(
&mut self,
e: TypedExpressionList<'ast, T>,
) -> TypedExpressionList<'ast, T> {
match e {
TypedExpressionList::FunctionCall(ref k, _, _, _) => {
if !k.id.starts_with('_') {
self.blocked = true;
}
}
_ => unreachable!(),
};
fold_expression_list(self, e)
fold_function_call_expression(self, ty, c)
}
}
@ -440,7 +370,7 @@ mod tests {
let s: TypedStatement<Bn128Field> = TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -448,14 +378,14 @@ mod tests {
),
vec![],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
);
assert_eq!(
u.fold_statement(s),
vec![TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(1)).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo").signature(
DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement])
@ -465,9 +395,9 @@ mod tests {
vec![
FieldElementExpression::Identifier(Identifier::from("a").version(0))
.into()
],
vec![Type::FieldElement],
]
)
.annotate(Types::new(vec![Type::FieldElement]))
)]
);
}
@ -887,14 +817,14 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element("a").into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32),
)],
vec![FieldElementExpression::Identifier("a".into()).into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint("n", UBitwidth::B32).into(),
@ -905,7 +835,7 @@ mod tests {
TypedStatement::Definition(
Variable::field_element("a").into(),
(FieldElementExpression::Identifier("a".into())
* FieldElementExpression::FunctionCall(
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier("n".into())
@ -962,7 +892,7 @@ mod tests {
),
TypedStatement::MultipleDefinition(
vec![Variable::field_element(Identifier::from("a").version(2)).into()],
TypedExpressionList::FunctionCall(
TypedExpressionList::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(1))
@ -972,8 +902,8 @@ mod tests {
Identifier::from("a").version(1),
)
.into()],
vec![Type::FieldElement],
),
)
.annotate(Types::new(vec![Type::FieldElement])),
),
TypedStatement::Definition(
Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(),
@ -984,7 +914,7 @@ mod tests {
TypedStatement::Definition(
Variable::field_element(Identifier::from("a").version(3)).into(),
(FieldElementExpression::Identifier(Identifier::from("a").version(2))
* FieldElementExpression::FunctionCall(
* FieldElementExpression::function_call(
DeclarationFunctionKey::with_location("main", "foo"),
vec![Some(
UExpressionInner::Identifier(Identifier::from("n").version(2))

View file

@ -405,78 +405,79 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
})
.collect(),
)],
ZirStatement::MultipleDefinition(lhs, rhs) => match rhs {
ZirExpressionList::EmbedCall(embed, generics, arguments) => match embed {
FlatEmbed::U64FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(64) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
) => match embed {
FlatEmbed::U64FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(64) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U32FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(32) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U16FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(16) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U8FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(8) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
_ => vec![ZirStatement::MultipleDefinition(
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(
embed,
generics,
arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
),
)],
},
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U32FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(32) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U16FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(16) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
FlatEmbed::U8FromBits => {
assert_eq!(lhs.len(), 1);
self.register(
lhs[0].clone(),
UMetadata {
max: T::from(2).pow(8) - T::from(1),
should_reduce: ShouldReduce::False,
},
);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(embed, generics, arguments),
)]
}
_ => vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::EmbedCall(
embed,
generics,
arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
),
)],
},
ZirStatement::Assertion(BooleanExpression::UintEq(box left, box right)) => {
let left = self.fold_uint_expression(left);

View file

@ -26,14 +26,16 @@ impl<'ast, T: Field> VariableReadRemover<'ast, T> {
Self::new().fold_program(p)
}
fn select<U: Select<'ast, T> + IfElse<'ast, T>>(
fn select<E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T>>(
&mut self,
a: ArrayExpression<'ast, T>,
i: UExpression<'ast, T>,
) -> U {
e: SelectExpression<'ast, T, E>,
) -> E::Inner {
let a = *e.array;
let i = *e.index;
match i.into_inner() {
UExpressionInner::Value(i) => {
U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32))
E::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32)).into_inner()
}
i => {
let size = match a.get_type().clone() {
@ -61,7 +63,7 @@ impl<'ast, T: Field> VariableReadRemover<'ast, T> {
(0..size)
.map(|i| {
U::select(
E::select(
a.clone(),
UExpressionInner::Value(i.into()).annotate(UBitwidth::B32),
)
@ -69,7 +71,7 @@ impl<'ast, T: Field> VariableReadRemover<'ast, T> {
.enumerate()
.rev()
.fold(None, |acc, (index, res)| match acc {
Some(acc) => Some(U::if_else(
Some(acc) => Some(E::if_else(
BooleanExpression::UintEq(
box i.clone().annotate(UBitwidth::B32),
box (index as u32).into(),
@ -80,69 +82,21 @@ impl<'ast, T: Field> VariableReadRemover<'ast, T> {
None => Some(res),
})
.unwrap()
.into_inner()
}
}
}
}
impl<'ast, T: Field> Folder<'ast, T> for VariableReadRemover<'ast, T> {
fn fold_field_expression(
fn fold_select_expression<
E: Expr<'ast, T> + Select<'ast, T> + IfElse<'ast, T> + From<TypedExpression<'ast, T>>,
>(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> FieldElementExpression<'ast, T> {
match e {
FieldElementExpression::Select(box a, box i) => self.select(a, i),
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
BooleanExpression::Select(box a, box i) => self.select(a, i),
e => fold_boolean_expression(self, e),
}
}
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> ArrayExpressionInner<'ast, T> {
match e {
ArrayExpressionInner::Select(box a, box i) => {
self.select::<ArrayExpression<'ast, T>>(a, i).into_inner()
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> StructExpressionInner<'ast, T> {
match e {
StructExpressionInner::Select(box a, box i) => {
self.select::<StructExpression<'ast, T>>(a, i).into_inner()
}
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> UExpressionInner<'ast, T> {
match e {
UExpressionInner::Select(box a, box i) => {
self.select::<UExpression<'ast, T>>(a, i).into_inner()
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
_: &E::Ty,
e: SelectExpression<'ast, T, E>,
) -> SelectOrExpression<'ast, T, E> {
SelectOrExpression::Expression(self.select(e))
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
@ -167,9 +121,9 @@ mod tests {
let access: TypedStatement<Bn128Field> = TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("b")),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2u32),
box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
FieldElementExpression::select(
ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2u32),
UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
)
.into(),
);
@ -194,15 +148,15 @@ mod tests {
box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32),
box UExpressionInner::Value(0).annotate(UBitwidth::B32)
),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
FieldElementExpression::select(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2u32),
box 0u32.into(),
0u32,
),
FieldElementExpression::Select(
box ArrayExpressionInner::Identifier("a".into())
FieldElementExpression::select(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 2u32),
box 1u32.into(),
1u32,
)
)
.into()

View file

@ -310,12 +310,13 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Definition(assignee, expr) => {
let expr = self.fold_expression(expr);
if is_constant(&assignee) {
vec![TypedStatement::Definition(assignee, expr)]
} else {
// Note: here we redefine the whole object, ideally we would only redefine some of it
// Example: `a[0][i] = 42` we redefine `a` but we could redefine just `a[0]`
let expr = self.fold_expression(expr);
let (variable, indices) = linear(assignee);

Some files were not shown because too many files have changed in this diff Show more