merge dev, implement inference for neg and pos, fix conclicts
This commit is contained in:
commit
9d13b4129d
203 changed files with 12873 additions and 7237 deletions
|
@ -4,6 +4,7 @@ jobs:
|
|||
build:
|
||||
docker:
|
||||
- image: zokrates/env:latest
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
|
@ -28,6 +29,7 @@ jobs:
|
|||
test:
|
||||
docker:
|
||||
- image: zokrates/env:latest
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
|
@ -42,6 +44,9 @@ jobs:
|
|||
- run:
|
||||
name: Check format
|
||||
command: cargo fmt --all -- --check
|
||||
- run:
|
||||
name: Run clippy
|
||||
command: cargo clippy
|
||||
- run:
|
||||
name: Build
|
||||
command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh
|
||||
|
@ -80,6 +85,7 @@ jobs:
|
|||
docker:
|
||||
- image: zokrates/env:latest
|
||||
- image: trufflesuite/ganache-cli:next
|
||||
resource_class: large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
|
|
74
Cargo.lock
generated
74
Cargo.lock
generated
|
@ -133,7 +133,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c"
|
||||
dependencies = [
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -145,7 +145,7 @@ dependencies = [
|
|||
"num-bigint 0.4.0",
|
||||
"num-traits 0.2.14",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -205,7 +205,7 @@ checksum = "5ac3d78c750b01f5df5b2e76d106ed31487a93b3868f14a7f0eb3a74f45e1d8a"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -653,7 +653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d"
|
||||
dependencies = [
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -664,7 +664,7 @@ checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -765,7 +765,7 @@ checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
|
@ -809,7 +809,7 @@ dependencies = [
|
|||
"num-traits 0.2.14",
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1059,9 +1059,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.49"
|
||||
version = "0.3.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc15e39392125075f60c95ba416f5381ff6c3a948ff02ab12464715adf56c821"
|
||||
checksum = "2d99f9e3e84b8f67f846ef5b4cbbc3b1c29f6c759fcbce6f01aa0e73d932a24c"
|
||||
dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
@ -1074,9 +1074,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
|||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.91"
|
||||
version = "0.2.92"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7"
|
||||
checksum = "56d855069fafbb9b344c0f962150cd2c1187975cb1c22c1522c240d8c4986714"
|
||||
|
||||
[[package]]
|
||||
name = "libgit2-sys"
|
||||
|
@ -1370,7 +1370,7 @@ dependencies = [
|
|||
"pest_meta",
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1773,7 +1773,7 @@ checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1866,9 +1866,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.64"
|
||||
version = "1.0.67"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f"
|
||||
checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
|
||||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
|
@ -1883,7 +1883,7 @@ checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
"unicode-xid 0.2.1",
|
||||
]
|
||||
|
||||
|
@ -2106,9 +2106,9 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
|
|||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.72"
|
||||
version = "0.2.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fe8f61dba8e5d645a4d8132dc7a0a66861ed5e1045d2c0ed940fab33bac0fbe"
|
||||
checksum = "83240549659d187488f91f33c0f8547cbfef0b2088bc470c116d1d260ef623d9"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"wasm-bindgen-macro",
|
||||
|
@ -2116,24 +2116,24 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.72"
|
||||
version = "0.2.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "046ceba58ff062da072c7cb4ba5b22a37f00a302483f7e2a6cdc18fedbdc1fd3"
|
||||
checksum = "ae70622411ca953215ca6d06d3ebeb1e915f0f6613e3b495122878d7ebec7dae"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-futures"
|
||||
version = "0.4.22"
|
||||
version = "0.4.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73157efb9af26fb564bb59a009afd1c7c334a44db171d280690d0c3faaec3468"
|
||||
checksum = "81b8b767af23de6ac18bf2168b690bed2902743ddf0fb39252e36f9e2bfc63ea"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"js-sys",
|
||||
|
@ -2143,9 +2143,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.72"
|
||||
version = "0.2.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ef9aa01d36cda046f797c57959ff5f3c615c9cc63997a8d545831ec7976819b"
|
||||
checksum = "3e734d91443f177bfdb41969de821e15c516931c3c3db3d318fa1b68975d0f6f"
|
||||
dependencies = [
|
||||
"quote 1.0.9",
|
||||
"wasm-bindgen-macro-support",
|
||||
|
@ -2153,28 +2153,28 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.72"
|
||||
version = "0.2.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96eb45c1b2ee33545a813a92dbb53856418bf7eb54ab34f7f7ff1448a5b3735d"
|
||||
checksum = "d53739ff08c8a68b0fdbcd54c372b8ab800b1449ab3c9d706503bc7dd1621b2c"
|
||||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.72"
|
||||
version = "0.2.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7148f4696fb4960a346eaa60bbfb42a1ac4ebba21f750f75fc1375b098d5ffa"
|
||||
checksum = "d9a543ae66aa233d14bb765ed9af4a33e81b8b58d1584cf1b47ff8cd0b9e4489"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-test"
|
||||
version = "0.3.22"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f002ea97b5abdb19aafd48cbb5a0a7f6931cf36ea05a0a46ccc95d9f4c2cf43"
|
||||
checksum = "e972e914de63aa53bd84865e54f5c761bd274d48e5be3a6329a662c0386aa67a"
|
||||
dependencies = [
|
||||
"console_error_panic_hook",
|
||||
"js-sys",
|
||||
|
@ -2186,9 +2186,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-test-macro"
|
||||
version = "0.3.22"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10a6c0bd3933daf64c78fc25a7452530f79fa7e21f77fa03d608d1e988a66735"
|
||||
checksum = "ea6153a8f9bf24588e9f25c87223414fff124049f68d3a442a0f0eab4768a8b6"
|
||||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
|
@ -2196,9 +2196,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.49"
|
||||
version = "0.3.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59fe19d70f5dacc03f6e46777213facae5ac3801575d56ca6cbd4c93dcd12310"
|
||||
checksum = "a905d57e488fec8861446d3393670fb50d27a262344013181c2cdf9fff5481be"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
|
@ -2252,7 +2252,7 @@ checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16"
|
|||
dependencies = [
|
||||
"proc-macro2 1.0.24",
|
||||
"quote 1.0.9",
|
||||
"syn 1.0.64",
|
||||
"syn 1.0.67",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
|
|
7
asserts.zok
Normal file
7
asserts.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def id<N>() -> u32:
|
||||
return N
|
||||
|
||||
def main():
|
||||
assert(id::<5>() == 5)
|
||||
assert(id::<6>() == 6)
|
||||
return
|
1
changelogs/unreleased/695-schaeff
Normal file
1
changelogs/unreleased/695-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Introduce constant generics for `u32` values. Introduce literal inference
|
1
changelogs/unreleased/754-schaeff
Normal file
1
changelogs/unreleased/754-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Make embed functions generic, enabling unpacking to any width at minimal cost
|
11
example.zok
Normal file
11
example.zok
Normal file
|
@ -0,0 +1,11 @@
|
|||
def foo<N>(field[N] x) -> field[N]:
|
||||
return x
|
||||
|
||||
def bar<N>(field[N] x) -> field[N]:
|
||||
field[N] r = x
|
||||
return r
|
||||
|
||||
def main(field[3] x) -> field[2]:
|
||||
field[2] z = foo(x)[0..2]
|
||||
|
||||
return bar(z)
|
13
scripts/benchmark.sh
Executable file
13
scripts/benchmark.sh
Executable file
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Usage: benchmark.sh <command>
|
||||
# For MacOS: install gtime with homebrew `brew install gnu-time`
|
||||
|
||||
cmd=$*
|
||||
format="mem=%KK rss=%MK elapsed=%E cpu=%P cpu.sys=%S inputs=%I outputs=%O"
|
||||
|
||||
if command -v gtime; then
|
||||
gtime -f "$format" $cmd
|
||||
else
|
||||
/usr/bin/time -f "$format" $cmd
|
||||
fi
|
|
@ -17,7 +17,7 @@ impl<T: From<usize>> Encode<T> for Inputs<T> {
|
|||
use std::collections::BTreeMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt;
|
||||
use zokrates_core::typed_absy::{Type, UBitwidth};
|
||||
use zokrates_core::typed_absy::types::{ConcreteType, UBitwidth};
|
||||
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -94,18 +94,20 @@ impl<T: Field> fmt::Display for Value<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> Value<T> {
|
||||
fn check(self, ty: Type) -> Result<CheckedValue<T>, String> {
|
||||
fn check(self, ty: ConcreteType) -> Result<CheckedValue<T>, String> {
|
||||
match (self, ty) {
|
||||
(Value::Field(f), Type::FieldElement) => Ok(CheckedValue::Field(f)),
|
||||
(Value::U8(f), Type::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)),
|
||||
(Value::U16(f), Type::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)),
|
||||
(Value::U32(f), Type::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)),
|
||||
(Value::Boolean(b), Type::Boolean) => Ok(CheckedValue::Boolean(b)),
|
||||
(Value::Array(a), Type::Array(array_type)) => {
|
||||
if a.len() != array_type.size {
|
||||
(Value::Field(f), ConcreteType::FieldElement) => Ok(CheckedValue::Field(f)),
|
||||
(Value::U8(f), ConcreteType::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)),
|
||||
(Value::U16(f), ConcreteType::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)),
|
||||
(Value::U32(f), ConcreteType::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)),
|
||||
(Value::Boolean(b), ConcreteType::Boolean) => Ok(CheckedValue::Boolean(b)),
|
||||
(Value::Array(a), ConcreteType::Array(array_type)) => {
|
||||
let size = array_type.size;
|
||||
|
||||
if a.len() != size as usize {
|
||||
Err(format!(
|
||||
"Expected array of size {}, found array of size {}",
|
||||
array_type.size,
|
||||
size,
|
||||
a.len()
|
||||
))
|
||||
} else {
|
||||
|
@ -116,15 +118,16 @@ impl<T: Field> Value<T> {
|
|||
Ok(CheckedValue::Array(a))
|
||||
}
|
||||
}
|
||||
(Value::Struct(mut s), Type::Struct(members)) => {
|
||||
if s.len() != members.len() {
|
||||
(Value::Struct(mut s), ConcreteType::Struct(struc)) => {
|
||||
if s.len() != struc.members_count() {
|
||||
Err(format!(
|
||||
"Expected {} member(s), found {}",
|
||||
members.len(),
|
||||
struc.members_count(),
|
||||
s.len()
|
||||
))
|
||||
} else {
|
||||
let s = members
|
||||
let s = struc
|
||||
.members
|
||||
.into_iter()
|
||||
.map(|member| {
|
||||
s.remove(&member.id)
|
||||
|
@ -167,7 +170,7 @@ impl<T: From<usize>> Encode<T> for CheckedValue<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> Decode<T> for CheckedValues<T> {
|
||||
type Expected = Vec<Type>;
|
||||
type Expected = Vec<ConcreteType>;
|
||||
|
||||
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
|
||||
CheckedValues(
|
||||
|
@ -185,23 +188,24 @@ impl<T: Field> Decode<T> for CheckedValues<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> Decode<T> for CheckedValue<T> {
|
||||
type Expected = Type;
|
||||
type Expected = ConcreteType;
|
||||
|
||||
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
|
||||
let mut raw = raw;
|
||||
|
||||
match expected {
|
||||
Type::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
|
||||
Type::Uint(UBitwidth::B8) => CheckedValue::U8(
|
||||
ConcreteType::Int => unreachable!(),
|
||||
ConcreteType::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
|
||||
ConcreteType::Uint(UBitwidth::B8) => CheckedValue::U8(
|
||||
u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
|
||||
),
|
||||
Type::Uint(UBitwidth::B16) => CheckedValue::U16(
|
||||
ConcreteType::Uint(UBitwidth::B16) => CheckedValue::U16(
|
||||
u16::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
|
||||
),
|
||||
Type::Uint(UBitwidth::B32) => CheckedValue::U32(
|
||||
ConcreteType::Uint(UBitwidth::B32) => CheckedValue::U32(
|
||||
u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
|
||||
),
|
||||
Type::Boolean => {
|
||||
ConcreteType::Boolean => {
|
||||
let v = raw.pop().unwrap();
|
||||
CheckedValue::Boolean(if v == 0.into() {
|
||||
false
|
||||
|
@ -211,12 +215,12 @@ impl<T: Field> Decode<T> for CheckedValue<T> {
|
|||
unreachable!()
|
||||
})
|
||||
}
|
||||
Type::Array(array_type) => CheckedValue::Array(
|
||||
ConcreteType::Array(array_type) => CheckedValue::Array(
|
||||
raw.chunks(array_type.ty.get_primitive_count())
|
||||
.map(|c| CheckedValue::decode(c.to_vec(), *array_type.ty.clone()))
|
||||
.collect(),
|
||||
),
|
||||
Type::Struct(members) => CheckedValue::Struct(
|
||||
ConcreteType::Struct(members) => CheckedValue::Struct(
|
||||
members
|
||||
.into_iter()
|
||||
.scan(0, |state, member| {
|
||||
|
@ -247,9 +251,9 @@ impl<T: Field> TryFrom<serde_json::Value> for Values<T> {
|
|||
match v {
|
||||
serde_json::Value::Array(a) => a
|
||||
.into_iter()
|
||||
.map(|v| Value::try_from(v))
|
||||
.map(Value::try_from)
|
||||
.collect::<Result<_, _>>()
|
||||
.map(|v| Values(v)),
|
||||
.map(Values),
|
||||
v => Err(format!("Expected an array of values, found `{}`", v)),
|
||||
}
|
||||
}
|
||||
|
@ -259,20 +263,22 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
|
|||
type Error = String;
|
||||
fn try_from(v: serde_json::Value) -> Result<Value<T>, Self::Error> {
|
||||
match v {
|
||||
serde_json::Value::String(s) => T::try_from_dec_str(&s)
|
||||
.map(|v| Value::Field(v))
|
||||
.or_else(|_| match s.len() {
|
||||
4 => u8::from_str_radix(&s[2..], 16)
|
||||
.map(|v| Value::U8(v))
|
||||
.map_err(|_| format!("Expected u8 value, found {}", s)),
|
||||
6 => u16::from_str_radix(&s[2..], 16)
|
||||
.map(|v| Value::U16(v))
|
||||
.map_err(|_| format!("Expected u16 value, found {}", s)),
|
||||
10 => u32::from_str_radix(&s[2..], 16)
|
||||
.map(|v| Value::U32(v))
|
||||
.map_err(|_| format!("Expected u32 value, found {}", s)),
|
||||
_ => Err(format!("Cannot parse {} to any type", s)),
|
||||
}),
|
||||
serde_json::Value::String(s) => {
|
||||
T::try_from_dec_str(&s)
|
||||
.map(Value::Field)
|
||||
.or_else(|_| match s.len() {
|
||||
4 => u8::from_str_radix(&s[2..], 16)
|
||||
.map(Value::U8)
|
||||
.map_err(|_| format!("Expected u8 value, found {}", s)),
|
||||
6 => u16::from_str_radix(&s[2..], 16)
|
||||
.map(Value::U16)
|
||||
.map_err(|_| format!("Expected u16 value, found {}", s)),
|
||||
10 => u32::from_str_radix(&s[2..], 16)
|
||||
.map(Value::U32)
|
||||
.map_err(|_| format!("Expected u32 value, found {}", s)),
|
||||
_ => Err(format!("Cannot parse {} to any type", s)),
|
||||
})
|
||||
}
|
||||
serde_json::Value::Bool(b) => Ok(Value::Boolean(b)),
|
||||
serde_json::Value::Number(n) => Err(format!(
|
||||
"Value `{}` isn't allowed, did you mean `\"{}\"`?",
|
||||
|
@ -280,14 +286,14 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
|
|||
)),
|
||||
serde_json::Value::Array(a) => a
|
||||
.into_iter()
|
||||
.map(|v| Value::try_from(v))
|
||||
.map(Value::try_from)
|
||||
.collect::<Result<_, _>>()
|
||||
.map(|v| Value::Array(v)),
|
||||
.map(Value::Array),
|
||||
serde_json::Value::Object(o) => o
|
||||
.into_iter()
|
||||
.map(|(k, v)| Value::try_from(v).map(|v| (k, v)))
|
||||
.collect::<Result<Map<_, _>, _>>()
|
||||
.map(|v| Value::Struct(v)),
|
||||
.map(Value::Struct),
|
||||
v => Err(format!("Value `{}` isn't allowed", v)),
|
||||
}
|
||||
}
|
||||
|
@ -320,10 +326,13 @@ impl<T: Field> Into<serde_json::Value> for CheckedValues<T> {
|
|||
fn parse<T: Field>(s: &str) -> Result<Values<T>, Error> {
|
||||
let json_values: serde_json::Value =
|
||||
serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?;
|
||||
Values::try_from(json_values).map_err(|e| Error::Conversion(e))
|
||||
Values::try_from(json_values).map_err(Error::Conversion)
|
||||
}
|
||||
|
||||
pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues<T>, Error> {
|
||||
pub fn parse_strict<T: Field>(
|
||||
s: &str,
|
||||
types: Vec<ConcreteType>,
|
||||
) -> Result<CheckedValues<T>, Error> {
|
||||
let parsed = parse(s)?;
|
||||
if parsed.0.len() != types.len() {
|
||||
return Err(Error::Type(format!(
|
||||
|
@ -338,7 +347,7 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues
|
|||
.zip(types.into_iter())
|
||||
.map(|(v, ty)| v.check(ty))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map_err(|e| Error::Type(e))?;
|
||||
.map_err(Error::Type)?;
|
||||
Ok(CheckedValues(checked))
|
||||
}
|
||||
|
||||
|
@ -403,14 +412,19 @@ mod tests {
|
|||
|
||||
mod strict {
|
||||
use super::*;
|
||||
use zokrates_core::typed_absy::types::{StructMember, StructType};
|
||||
use zokrates_core::typed_absy::types::{
|
||||
ConcreteStructMember, ConcreteStructType, ConcreteType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn fields() {
|
||||
let s = r#"["1", "2"]"#;
|
||||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(s, vec![Type::FieldElement, Type::FieldElement])
|
||||
.unwrap(),
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![ConcreteType::FieldElement, ConcreteType::FieldElement]
|
||||
)
|
||||
.unwrap(),
|
||||
CheckedValues(vec![
|
||||
CheckedValue::Field(1.into()),
|
||||
CheckedValue::Field(2.into())
|
||||
|
@ -422,7 +436,8 @@ mod tests {
|
|||
fn bools() {
|
||||
let s = "[true, false]";
|
||||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(s, vec![Type::Boolean, Type::Boolean]).unwrap(),
|
||||
parse_strict::<Bn128Field>(s, vec![ConcreteType::Boolean, ConcreteType::Boolean])
|
||||
.unwrap(),
|
||||
CheckedValues(vec![
|
||||
CheckedValue::Boolean(true),
|
||||
CheckedValue::Boolean(false)
|
||||
|
@ -434,7 +449,11 @@ mod tests {
|
|||
fn array() {
|
||||
let s = "[[true, false]]";
|
||||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(s, vec![Type::array(Type::Boolean, 2)]).unwrap(),
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![ConcreteType::array((ConcreteType::Boolean, 2usize))]
|
||||
)
|
||||
.unwrap(),
|
||||
CheckedValues(vec![CheckedValue::Array(vec![
|
||||
CheckedValue::Boolean(true),
|
||||
CheckedValue::Boolean(false)
|
||||
|
@ -448,10 +467,13 @@ mod tests {
|
|||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![Type::Struct(StructType::new(
|
||||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![StructMember::new("a".into(), Type::FieldElement)]
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
)]
|
||||
))]
|
||||
)
|
||||
.unwrap(),
|
||||
|
@ -466,10 +488,13 @@ mod tests {
|
|||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![Type::Struct(StructType::new(
|
||||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![StructMember::new("a".into(), Type::FieldElement)]
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
)]
|
||||
))]
|
||||
)
|
||||
.unwrap_err(),
|
||||
|
@ -480,10 +505,13 @@ mod tests {
|
|||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![Type::Struct(StructType::new(
|
||||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![StructMember::new("a".into(), Type::FieldElement)]
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
)]
|
||||
))]
|
||||
)
|
||||
.unwrap_err(),
|
||||
|
@ -494,10 +522,13 @@ mod tests {
|
|||
assert_eq!(
|
||||
parse_strict::<Bn128Field>(
|
||||
s,
|
||||
vec![Type::Struct(StructType::new(
|
||||
vec![ConcreteType::Struct(ConcreteStructType::new(
|
||||
"".into(),
|
||||
"".into(),
|
||||
vec![StructMember::new("a".into(), Type::FieldElement)]
|
||||
vec![ConcreteStructMember::new(
|
||||
"a".into(),
|
||||
ConcreteType::FieldElement
|
||||
)]
|
||||
))]
|
||||
)
|
||||
.unwrap_err(),
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
- [Control flow](language/control_flow.md)
|
||||
- [Imports](language/imports.md)
|
||||
- [Comments](language/comments.md)
|
||||
- [Generics](language/generics.md)
|
||||
- [Macros](language/macros.md)
|
||||
|
||||
- [Toolbox](toolbox/index.md)
|
||||
|
|
|
@ -12,6 +12,12 @@ Arguments are passed by value.
|
|||
{{#include ../../../zokrates_cli/examples/book/side_effects.zok}}
|
||||
```
|
||||
|
||||
Generic paramaters, if any, must be compile-time constants. They are inferred by the compiler if that is possible, but can also be provided explicitly.
|
||||
|
||||
```zokrates
|
||||
{{#include ../../../zokrates_cli/examples/book/generic_call.zok}}
|
||||
```
|
||||
|
||||
### If-expressions
|
||||
|
||||
An if-expression allows you to branch your code depending on a boolean condition.
|
||||
|
@ -28,7 +34,7 @@ For loops are available with the following syntax:
|
|||
{{#include ../../../zokrates_cli/examples/book/for.zok}}
|
||||
```
|
||||
|
||||
The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs.
|
||||
The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. They can depend on generic parameters.
|
||||
|
||||
### Assertions
|
||||
|
||||
|
|
|
@ -7,7 +7,14 @@ A function has to be declared at the top level before it is called.
|
|||
```
|
||||
|
||||
A function's signature has to be explicitly provided.
|
||||
Functions can return many values by providing them as a comma-separated list.
|
||||
|
||||
A function can be generic over any number of values of type `u32`.
|
||||
|
||||
```zokrates
|
||||
{{#include ../../../zokrates_cli/examples/book/generic_function_declaration.zok}}
|
||||
```
|
||||
|
||||
Functions can return multiple values by providing them as a comma-separated list.
|
||||
|
||||
```zokrates
|
||||
{{#include ../../../zokrates_cli/examples/book/multi_return.zok}}
|
||||
|
|
7
zokrates_book/src/language/generics.md
Normal file
7
zokrates_book/src/language/generics.md
Normal file
|
@ -0,0 +1,7 @@
|
|||
## Generics
|
||||
|
||||
ZoKrates supports code that is generic over constants of the `u32` type. No specific keyword is used: the compiler determines if the generic parameters are indeed constant at compile time. Here's an example of generic code in ZoKrates:
|
||||
|
||||
```zokrates
|
||||
{{#include ../../../zokrates_cli/examples/book/generics.zok}}
|
||||
```
|
|
@ -8,7 +8,7 @@ ZoKrates currently exposes two primitive types and two complex types:
|
|||
|
||||
This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers.
|
||||
|
||||
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/reference/proving_schemes.html#alt_bn128) curve supported by Ethereum.
|
||||
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/toolbox/proving_schemes.html#alt_bn128) curve supported by Ethereum.
|
||||
|
||||
While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have:
|
||||
|
||||
|
@ -32,13 +32,23 @@ Similarly to booleans, unsigned integer inputs of the main function only accept
|
|||
|
||||
The division operation calculates the standard floor division for integers. The `%` operand can be used to obtain the remainder.
|
||||
|
||||
### Numeric inference
|
||||
|
||||
In the case of decimal literals like `42`, the compiler tries to find the appropriate type (`field`, `u8`, `u16` or `u32`) depending on the context. If it cannot converge to a single option, an error is returned. This means that there is no default type for decimal literals.
|
||||
|
||||
All operations between literals have the semantics of the infered type.
|
||||
|
||||
```zokrates
|
||||
{{#include ../../../zokrates_cli/examples/book/numeric_inference.zok}}
|
||||
```
|
||||
|
||||
## Complex Types
|
||||
|
||||
ZoKrates provides two complex types: arrays and structs.
|
||||
|
||||
### Arrays
|
||||
|
||||
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time.
|
||||
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](/language/generics.html)
|
||||
Arrays can contain elements of any type and have arbitrary dimensions.
|
||||
|
||||
The following example code shows examples of how to use arrays:
|
||||
|
|
|
@ -11,5 +11,5 @@ fn export_stdlib() {
|
|||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let mut options = CopyOptions::new();
|
||||
options.overwrite = true;
|
||||
copy_items(&vec!["tests/contract"], out_dir, &options).unwrap();
|
||||
copy_items(&["tests/contract"], out_dir, &options).unwrap();
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
def main() -> field:
|
||||
field[3] a = [1, 2, 3]
|
||||
field c = 0
|
||||
for field i in 0..3 do
|
||||
def main() -> u32:
|
||||
u32[3] a = [1, 2, 3]
|
||||
u32 c = 0
|
||||
for u32 i in 0..3 do
|
||||
c = c + a[i]
|
||||
endfor
|
||||
return c
|
|
@ -1,7 +1,7 @@
|
|||
def main() -> (field[3]):
|
||||
field[3] a = [1, 2, 3]
|
||||
field[3] c = [4, 5, 6]
|
||||
for field i in 0..3 do
|
||||
def main() -> (u32[3]):
|
||||
u32[3] a = [1, 2, 3]
|
||||
u32[3] c = [4, 5, 6]
|
||||
for u32 i in 0..3 do
|
||||
c[i] = c[i] + a[i]
|
||||
endfor
|
||||
return c
|
|
@ -3,7 +3,7 @@ def main(bool[3] a) -> (field[3]):
|
|||
a[1] = true || a[2]
|
||||
a[2] = a[0]
|
||||
field[3] result = [0; 3]
|
||||
for field i in 0..3 do
|
||||
for u32 i in 0..3 do
|
||||
result[i] = if a[i] then 33 else 0 fi
|
||||
endfor
|
||||
return result
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
def main(field[2][2][2] cube) -> field:
|
||||
field res = 0
|
||||
|
||||
for field i in 0..2 do
|
||||
for field j in 0..2 do
|
||||
for field k in 0..2 do
|
||||
for u32 i in 0..2 do
|
||||
for u32 j in 0..2 do
|
||||
for u32 k in 0..2 do
|
||||
res = res + cube[i][j][k]
|
||||
endfor
|
||||
endfor
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
def main(field index, field[5] array) -> field:
|
||||
def main(u32 index, field[5] array) -> field:
|
||||
return array[index]
|
|
@ -1,4 +1,4 @@
|
|||
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]):
|
||||
def main(field[10][10][10] a, u32 i, u32 j, u32 k) -> (field[3]):
|
||||
a[i][j][k] = 42
|
||||
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
|
||||
return b[0]
|
4
zokrates_cli/examples/arrays/repeat.zok
Normal file
4
zokrates_cli/examples/arrays/repeat.zok
Normal file
|
@ -0,0 +1,4 @@
|
|||
def main(field a) -> field[4]:
|
||||
u32 SIZE = 4
|
||||
field[SIZE] res = [a; SIZE]
|
||||
return res
|
|
@ -1,6 +1,6 @@
|
|||
def slice32from(field offset, field[2048] input) -> (field[32]):
|
||||
def slice32from(u32 offset, field[2048] input) -> (field[32]):
|
||||
field[32] result = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
for field i in 0..32 do
|
||||
for u32 i in 0..32 do
|
||||
result[i] = input[offset + i]
|
||||
endfor
|
||||
return result
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
def get(field[32] array, field index) -> field:
|
||||
def get(field[32] array, u32 index) -> field:
|
||||
return array[index]
|
||||
|
||||
def main() -> field:
|
||||
|
|
|
@ -5,4 +5,6 @@ def main() -> field:
|
|||
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
|
||||
field[2] d = a[1..3] // initialize an array copying a slice from `a`
|
||||
bool[3] e = [true, true || false, true] // initialize a boolean array
|
||||
u32 SIZE = 3
|
||||
field[SIZE] f = [1, 2, 3] // initialize a field array with a size that's a compile-time constant
|
||||
return a[0] + b[1] + c[2]
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
def main() -> ():
|
||||
assert(1 == 2)
|
||||
assert(1f == 2f)
|
||||
return
|
|
@ -1,7 +1,7 @@
|
|||
def main() -> field:
|
||||
field res = 0
|
||||
for field i in 0..4 do
|
||||
for field j in i..5 do
|
||||
def main() -> u32:
|
||||
u32 res = 0
|
||||
for u32 i in 0..4 do
|
||||
for u32 j in i..5 do
|
||||
res = res + i
|
||||
endfor
|
||||
endfor
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
def main() -> field:
|
||||
field a = 0
|
||||
for field i in 0..5 do
|
||||
def main() -> u32:
|
||||
u32 a = 0
|
||||
for u32 i in 0..5 do
|
||||
a = a + i
|
||||
endfor
|
||||
// return i <- not allowed
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
def foo() -> field:
|
||||
return 1
|
||||
def foo(field a, field b) -> field:
|
||||
return a + b
|
||||
|
||||
def main() -> field:
|
||||
return foo()
|
||||
return foo(1, 2)
|
7
zokrates_cli/examples/book/generic_call.zok
Normal file
7
zokrates_cli/examples/book/generic_call.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def foo<N, P>() -> field[P]:
|
||||
return [42; P]
|
||||
|
||||
def main() -> field[2]:
|
||||
// `P` is inferred from the declaration of `res`, while `N` is provided explicitly
|
||||
field[2] res = foo::<3, _>()
|
||||
return res
|
|
@ -0,0 +1,6 @@
|
|||
def foo<N>() -> field[N]:
|
||||
return [42; N]
|
||||
|
||||
def main() -> field[2]:
|
||||
field[2] res = foo()
|
||||
return res
|
9
zokrates_cli/examples/book/generics.zok
Normal file
9
zokrates_cli/examples/book/generics.zok
Normal file
|
@ -0,0 +1,9 @@
|
|||
def sum<N>(field[N] a) -> field:
|
||||
field res = 0
|
||||
for u32 i in 0..N do
|
||||
res = res + a[i]
|
||||
endfor
|
||||
return res
|
||||
|
||||
def main(field[3] a) -> field:
|
||||
return sum(a)
|
|
@ -1,7 +1,7 @@
|
|||
def main() -> field:
|
||||
field a = 2
|
||||
// field a = 3 <- not allowed
|
||||
for field i in 0..5 do
|
||||
for u32 i in 0..5 do
|
||||
// field a = 7 <- not allowed
|
||||
endfor
|
||||
return a
|
9
zokrates_cli/examples/book/numeric_inference.zok
Normal file
9
zokrates_cli/examples/book/numeric_inference.zok
Normal file
|
@ -0,0 +1,9 @@
|
|||
def main():
|
||||
// `255` is infered to `255f`, and the addition happens between field elements
|
||||
assert(255 + 1f == 256)
|
||||
|
||||
// `255` is infered to `255u8`, and the addition happens between u8
|
||||
// This causes an overflow
|
||||
assert(255 + 1u8 == 0)
|
||||
|
||||
return
|
|
@ -0,0 +1,5 @@
|
|||
import "EMBED/u8_to_bits" as u8_to_bits
|
||||
|
||||
def main(u8 x):
|
||||
bool[32] b = u8_to_bits(x) // note the incorrect array length on the left
|
||||
return
|
|
@ -0,0 +1,8 @@
|
|||
def foo<N>(field[N] a) -> bool:
|
||||
field[3] b = a
|
||||
return true
|
||||
|
||||
|
||||
def main(field[1] a):
|
||||
assert(foo(a))
|
||||
return
|
|
@ -0,0 +1,3 @@
|
|||
def main():
|
||||
assert([1] == [1, 2])
|
||||
return
|
|
@ -0,0 +1,2 @@
|
|||
def main<P>(field[P] a):
|
||||
return
|
|
@ -0,0 +1,5 @@
|
|||
def foo<P>(field[P] a, field[P] b) -> field:
|
||||
return 42
|
||||
|
||||
def main() -> field:
|
||||
return foo([1, 2], [1])
|
|
@ -0,0 +1,3 @@
|
|||
def main():
|
||||
assert([[1]] == [1, 2])
|
||||
return
|
2
zokrates_cli/examples/compile_errors/generics/unused.zok
Normal file
2
zokrates_cli/examples/compile_errors/generics/unused.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main<P>():
|
||||
return
|
|
@ -1,4 +1,4 @@
|
|||
def main() -> field:
|
||||
for field i in 0..5 do
|
||||
for u32 i in 0..5 do
|
||||
endfor
|
||||
return i
|
|
@ -14,6 +14,6 @@ def main(field a) -> field:
|
|||
assert(2 * b == a * 12 + 60)
|
||||
field c = 7 * (b + a)
|
||||
assert(isEqual(c, 7 * b + 7 * a))
|
||||
field k = if [1, 2] == [3, 4] then 1 else 3 fi
|
||||
field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
|
||||
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
|
||||
return b + c
|
|
@ -1,6 +1,10 @@
|
|||
def bound(field x) -> u32:
|
||||
return 41 + 1
|
||||
|
||||
def main(field a) -> field:
|
||||
field x = 7
|
||||
for field i in 0..10 do
|
||||
x = x + 1
|
||||
for u32 i in 0..bound(x) do
|
||||
// x = x + a
|
||||
x = x + a
|
||||
endfor
|
||||
|
|
|
@ -4,14 +4,14 @@ def lt(field a,field b) -> bool:
|
|||
def cutoff() -> field:
|
||||
return 31337
|
||||
|
||||
def getThing(field index) -> field:
|
||||
def getThing(u32 index) -> field:
|
||||
field[6] a = [13, 23, 43, 53, 73, 83]
|
||||
return a[index]
|
||||
|
||||
def cubeThing(field thing) -> field:
|
||||
return thing**3
|
||||
|
||||
def main(field index) -> bool:
|
||||
def main(u32 index) -> bool:
|
||||
field thing = getThing(index)
|
||||
thing = cubeThing(thing)
|
||||
return lt(cutoff(), thing)
|
||||
|
|
|
@ -4,9 +4,6 @@ import "ecc/babyjubjubParams" as context
|
|||
from "ecc/babyjubjubParams" import BabyJubJubParams
|
||||
import "hashes/utils/256bitsDirectionHelper" as multiplex
|
||||
|
||||
def multiplex(bool selector, u32[8] left, u32[8] right) -> (u32[8]):
|
||||
return if selector then right else left fi
|
||||
|
||||
// Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes
|
||||
// directionSelector=> true if current digest is on the rhs of the hash
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import "utils/casts/u32_to_field" as to_field
|
||||
|
||||
// Binomial Coeffizient, n!/(k!*(n-k)!).
|
||||
def fac(field x) -> field:
|
||||
field f = 1
|
||||
field counter = 0
|
||||
for field i in 1..100 do
|
||||
f = if counter == x then f else f * i fi
|
||||
for u32 i in 1..100 do
|
||||
f = if counter == x then f else f * to_field(i) fi
|
||||
counter = if counter == x then counter else counter + 1 fi
|
||||
endfor
|
||||
return f
|
||||
|
|
|
@ -2,7 +2,7 @@ def main() -> field:
|
|||
field a = 1 + 2 + 3
|
||||
field b = if 1 < a then 3 else a + 3 fi
|
||||
field c = if b + a == 2 then 1 else b fi
|
||||
for field e in 0..2 do
|
||||
for u32 e in 0..2 do
|
||||
field g = 4
|
||||
c = c + g
|
||||
endfor
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
def main() -> field:
|
||||
field a = 2
|
||||
return 2**(a**2 + 2)
|
||||
u32 a = 2
|
||||
return 2**(a * 2 + 2)
|
|
@ -32,26 +32,26 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva
|
|||
bool res = true
|
||||
|
||||
// go through the whole grid and check that all elements are valid
|
||||
for field i in 0..4 do
|
||||
for field j in 0..4 do
|
||||
for u32 i in 0..4 do
|
||||
for u32 j in 0..4 do
|
||||
res = res && validateInput(a[i][j])
|
||||
endfor
|
||||
endfor
|
||||
|
||||
// go through the 4 2x2 boxes and check that they do not contain duplicates
|
||||
for field i in 0..1 do
|
||||
for field j in 0..1 do
|
||||
for u32 i in 0..1 do
|
||||
for u32 j in 0..1 do
|
||||
res = res && checkNoDuplicates(a[2*i][2*i], a[2*i][2*i + 1], a[2*i + 1][2*i], a[2*i + 1][2*i + 1])
|
||||
endfor
|
||||
endfor
|
||||
|
||||
// go through the 4 rows and check that they do not contain duplicates
|
||||
for field i in 0..4 do
|
||||
for u32 i in 0..4 do
|
||||
res = res && checkNoDuplicates(a[i][0], a[i][1], a[i][2], a[i][3])
|
||||
endfor
|
||||
|
||||
// go through the 4 columns and check that they do not contain duplicates
|
||||
for field j in 0..4 do
|
||||
for u32 j in 0..4 do
|
||||
res = res && checkNoDuplicates(a[0][j], a[1][j], a[2][j], a[3][j])
|
||||
endfor
|
||||
|
||||
|
|
|
@ -10,6 +10,6 @@ def isWaldo(field a, field p, field q) -> bool:
|
|||
return a == p * q
|
||||
|
||||
// define all
|
||||
def main(field[3] a, private field index, private field p, private field q) -> bool:
|
||||
def main(field[3] a, private u32 index, private field p, private field q) -> bool:
|
||||
// prover provides the index of Waldo
|
||||
return isWaldo(a[index], p, q)
|
|
@ -32,10 +32,13 @@ fn cli() -> Result<(), String> {
|
|||
compile::subcommand(),
|
||||
check::subcommand(),
|
||||
compute_witness::subcommand(),
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
setup::subcommand(),
|
||||
export_verifier::subcommand(),
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
generate_proof::subcommand(),
|
||||
print_proof::subcommand(),
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
verify::subcommand()])
|
||||
.get_matches();
|
||||
|
||||
|
@ -140,7 +143,7 @@ mod tests {
|
|||
let interpreter = ir::Interpreter::default();
|
||||
|
||||
let _ = interpreter
|
||||
.execute(&artifacts.prog(), &vec![Bn128Field::from(0)])
|
||||
.execute(&artifacts.prog(), &[Bn128Field::from(0)])
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
@ -169,7 +172,7 @@ mod tests {
|
|||
|
||||
let interpreter = ir::Interpreter::default();
|
||||
|
||||
let res = interpreter.execute(&artifacts.prog(), &vec![Bn128Field::from(0)]);
|
||||
let res = interpreter.execute(&artifacts.prog(), &[Bn128Field::from(0)]);
|
||||
|
||||
assert!(res.is_err());
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ lazy_static! {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
|
||||
if cfg!(feature = "ark") {
|
||||
if cfg!(feature = "bellman") {
|
||||
|
@ -26,27 +27,21 @@ pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
|
|||
} else {
|
||||
&[LIBSNARK, ARK]
|
||||
}
|
||||
} else if cfg!(feature = "bellman") {
|
||||
&[BELLMAN, LIBSNARK]
|
||||
} else {
|
||||
if cfg!(feature = "bellman") {
|
||||
&[BELLMAN, LIBSNARK]
|
||||
} else {
|
||||
&[LIBSNARK]
|
||||
}
|
||||
&[LIBSNARK]
|
||||
}
|
||||
} else if cfg!(feature = "ark") {
|
||||
if cfg!(feature = "bellman") {
|
||||
&[BELLMAN, ARK]
|
||||
} else {
|
||||
&[ARK]
|
||||
}
|
||||
} else if cfg!(feature = "bellman") {
|
||||
&[BELLMAN]
|
||||
} else {
|
||||
if cfg!(feature = "ark") {
|
||||
if cfg!(feature = "bellman") {
|
||||
&[BELLMAN, ARK]
|
||||
} else {
|
||||
&[ARK]
|
||||
}
|
||||
} else {
|
||||
if cfg!(feature = "bellman") {
|
||||
&[BELLMAN]
|
||||
} else {
|
||||
&[]
|
||||
}
|
||||
}
|
||||
&[]
|
||||
};
|
||||
|
||||
pub const BN128: &str = "bn128";
|
||||
|
|
|
@ -69,7 +69,7 @@ fn cli_check<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
format!(
|
||||
"{}:{}",
|
||||
file.strip_prefix(std::env::current_dir().unwrap())
|
||||
.unwrap_or(file.as_path())
|
||||
.unwrap_or_else(|_| file.as_path())
|
||||
.display(),
|
||||
e.value()
|
||||
)
|
||||
|
|
|
@ -92,9 +92,9 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
let fmt_error = |e: &CompileError| {
|
||||
let file = e.file().canonicalize().unwrap();
|
||||
format!(
|
||||
"{}:{}",
|
||||
"{}: {}",
|
||||
file.strip_prefix(std::env::current_dir().unwrap())
|
||||
.unwrap_or(file.as_path())
|
||||
.unwrap_or_else(|_| file.as_path())
|
||||
.display(),
|
||||
e.value()
|
||||
)
|
||||
|
@ -154,7 +154,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
|
|||
.map_err(|why| format!("Couldn't create {}: {}", hr_output_path.display(), why))?;
|
||||
|
||||
let mut hrofb = BufWriter::new(hr_output_file);
|
||||
write!(&mut hrofb, "{}\n", program_flattened)
|
||||
writeln!(&mut hrofb, "{}", program_flattened)
|
||||
.map_err(|_| "Unable to write data to file".to_string())?;
|
||||
hrofb
|
||||
.flush()
|
||||
|
|
|
@ -8,7 +8,7 @@ use zokrates_abi::Encode;
|
|||
use zokrates_core::ir;
|
||||
use zokrates_core::ir::ProgEnum;
|
||||
use zokrates_core::typed_absy::abi::Abi;
|
||||
use zokrates_core::typed_absy::{Signature, Type};
|
||||
use zokrates_core::typed_absy::types::{ConcreteSignature, ConcreteType};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub fn subcommand() -> App<'static, 'static> {
|
||||
|
@ -106,9 +106,12 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
|||
|
||||
abi.signature()
|
||||
}
|
||||
false => Signature::new()
|
||||
.inputs(vec![Type::FieldElement; ir_prog.main.arguments.len()])
|
||||
.outputs(vec![Type::FieldElement; ir_prog.main.returns.len()]),
|
||||
false => ConcreteSignature::new()
|
||||
.inputs(vec![
|
||||
ConcreteType::FieldElement;
|
||||
ir_prog.main.arguments.len()
|
||||
])
|
||||
.outputs(vec![ConcreteType::FieldElement; ir_prog.main.returns.len()]),
|
||||
};
|
||||
|
||||
use zokrates_abi::Inputs;
|
||||
|
@ -123,8 +126,8 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
|||
a.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.unwrap_or(Ok(vec![]))
|
||||
.map(|v| Inputs::Raw(v))
|
||||
.unwrap_or_else(|| Ok(vec![]))
|
||||
.map(Inputs::Raw)
|
||||
}
|
||||
// take stdin arguments
|
||||
true => {
|
||||
|
@ -137,7 +140,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
|||
use zokrates_abi::parse_strict;
|
||||
|
||||
parse_strict(&input, signature.inputs)
|
||||
.map(|parsed| Inputs::Abi(parsed))
|
||||
.map(Inputs::Abi)
|
||||
.map_err(|why| why.to_string())
|
||||
}
|
||||
Err(_) => Err(String::from("???")),
|
||||
|
@ -148,10 +151,10 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
|
|||
Ok(_) => {
|
||||
input.retain(|x| x != '\n');
|
||||
input
|
||||
.split(" ")
|
||||
.split(' ')
|
||||
.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.map(|v| Inputs::Raw(v))
|
||||
.map(Inputs::Raw)
|
||||
}
|
||||
Err(_) => Err(String::from("???")),
|
||||
},
|
||||
|
|
|
@ -174,7 +174,7 @@ fn cli_generate_proof<T: Field, S: Scheme<T>, B: Backend<T, S>>(
|
|||
.write(proof.as_bytes())
|
||||
.map_err(|why| format!("Couldn't write to {}: {}", proof_path.display(), why))?;
|
||||
|
||||
println!("Proof:\n{}", format!("{}", proof));
|
||||
println!("Proof:\n{}", proof);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
pub mod check;
|
||||
pub mod compile;
|
||||
pub mod compute_witness;
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
pub mod export_verifier;
|
||||
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
|
||||
pub mod generate_proof;
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import "utils/casts/u32_to_field" as to_field
|
||||
|
||||
// Binomial Coeffizient, n!/(k!*(n-k)!).
|
||||
def fac(field x) -> field:
|
||||
field f = 1
|
||||
field counter = 0
|
||||
for field i in 1..100 do
|
||||
f = if counter == x then f else f * i fi
|
||||
for u32 i in 1..100 do
|
||||
f = if counter == x then f else f * to_field(i) fi
|
||||
counter = if counter == x then counter else counter + 1 fi
|
||||
endfor
|
||||
return f
|
||||
|
|
|
@ -3,7 +3,7 @@ extern crate serde_json;
|
|||
|
||||
#[cfg(test)]
|
||||
mod integration {
|
||||
use assert_cli;
|
||||
|
||||
use serde_json::from_reader;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
|
@ -147,11 +147,11 @@ mod integration {
|
|||
.map_err(|why| why.to_string())
|
||||
.unwrap();
|
||||
|
||||
let signature = abi.signature().clone();
|
||||
let signature = abi.signature();
|
||||
|
||||
let inputs_abi: zokrates_abi::Inputs<zokrates_field::Bn128Field> =
|
||||
parse_strict(&json_input_str, signature.inputs)
|
||||
.map(|parsed| zokrates_abi::Inputs::Abi(parsed))
|
||||
.map(zokrates_abi::Inputs::Abi)
|
||||
.map_err(|why| why.to_string())
|
||||
.unwrap();
|
||||
let inputs_raw: Vec<_> = inputs_abi
|
||||
|
@ -169,7 +169,7 @@ mod integration {
|
|||
inline_witness_path.to_str().unwrap(),
|
||||
];
|
||||
|
||||
if inputs_raw.len() > 0 {
|
||||
if !inputs_raw.is_empty() {
|
||||
compute_inline.push("-a");
|
||||
|
||||
for arg in &inputs_raw {
|
||||
|
@ -202,7 +202,7 @@ mod integration {
|
|||
|
||||
assert_eq!(inline_witness, witness);
|
||||
|
||||
for line in expected_witness.as_str().split("\n") {
|
||||
for line in expected_witness.as_str().split('\n') {
|
||||
assert!(
|
||||
witness.contains(line),
|
||||
"Witness generation failed for {}\n\nLine \"{}\" not found in witness",
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use crate::absy;
|
||||
use crate::imports;
|
||||
|
||||
use num::ToPrimitive;
|
||||
use num_bigint::BigUint;
|
||||
use zokrates_pest_ast as pest;
|
||||
|
||||
|
@ -10,14 +9,14 @@ impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
|
|||
absy::Module::with_symbols(
|
||||
prog.structs
|
||||
.into_iter()
|
||||
.map(|t| absy::SymbolDeclarationNode::from(t))
|
||||
.map(absy::SymbolDeclarationNode::from)
|
||||
.chain(
|
||||
prog.functions
|
||||
.into_iter()
|
||||
.map(|f| absy::SymbolDeclarationNode::from(f)),
|
||||
.map(absy::SymbolDeclarationNode::from),
|
||||
),
|
||||
)
|
||||
.imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i)))
|
||||
.imports(prog.imports.into_iter().map(absy::ImportNode::from))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -31,17 +30,16 @@ impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
|
|||
.alias(import.alias.map(|a| a.span.as_str()))
|
||||
.span(import.span)
|
||||
}
|
||||
pest::ImportDirective::From(import) => imports::Import::new(
|
||||
Some(import.symbol.span.as_str()),
|
||||
std::path::Path::new(import.source.span.as_str()),
|
||||
)
|
||||
.alias(
|
||||
import
|
||||
.alias
|
||||
.map(|a| a.span.as_str())
|
||||
.or(Some(import.symbol.span.as_str())),
|
||||
)
|
||||
.span(import.span),
|
||||
pest::ImportDirective::From(import) => {
|
||||
let symbol_str = import.symbol.span.as_str();
|
||||
|
||||
imports::Import::new(
|
||||
Some(import.symbol.span.as_str()),
|
||||
std::path::Path::new(import.source.span.as_str()),
|
||||
)
|
||||
.alias(import.alias.map(|a| a.span.as_str()).or(Some(symbol_str)))
|
||||
.span(import.span)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -58,7 +56,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
|
|||
fields: definition
|
||||
.fields
|
||||
.into_iter()
|
||||
.map(|f| absy::StructDefinitionFieldNode::from(f))
|
||||
.map(absy::StructDefinitionFieldNode::from)
|
||||
.collect(),
|
||||
}
|
||||
.span(span.clone());
|
||||
|
@ -72,7 +70,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
|
|||
}
|
||||
|
||||
impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'ast> {
|
||||
fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode {
|
||||
fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode<'ast> {
|
||||
use crate::absy::NodeValue;
|
||||
|
||||
let span = field.span;
|
||||
|
@ -92,6 +90,13 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
|
|||
let span = function.span;
|
||||
|
||||
let signature = absy::UnresolvedSignature::new()
|
||||
.generics(
|
||||
function
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(absy::ConstantGenericNode::from)
|
||||
.collect(),
|
||||
)
|
||||
.inputs(
|
||||
function
|
||||
.parameters
|
||||
|
@ -105,7 +110,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
|
|||
.returns
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|r| absy::UnresolvedTypeNode::from(r))
|
||||
.map(absy::UnresolvedTypeNode::from)
|
||||
.collect(),
|
||||
);
|
||||
|
||||
|
@ -115,12 +120,12 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
|
|||
arguments: function
|
||||
.parameters
|
||||
.into_iter()
|
||||
.map(|a| absy::ParameterNode::from(a))
|
||||
.map(absy::ParameterNode::from)
|
||||
.collect(),
|
||||
statements: function
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| statements_from_statement(s))
|
||||
.flat_map(statements_from_statement)
|
||||
.collect(),
|
||||
signature,
|
||||
}
|
||||
|
@ -134,6 +139,16 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::IdentifierExpression<'ast>> for absy::ConstantGenericNode<'ast> {
|
||||
fn from(g: pest::IdentifierExpression<'ast>) -> absy::ConstantGenericNode<'ast> {
|
||||
use absy::NodeValue;
|
||||
|
||||
let name = g.span.as_str();
|
||||
|
||||
name.span(g.span)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::Parameter<'ast>> for absy::ParameterNode<'ast> {
|
||||
fn from(param: pest::Parameter<'ast>) -> absy::ParameterNode<'ast> {
|
||||
use crate::absy::NodeValue;
|
||||
|
@ -247,7 +262,7 @@ impl<'ast> From<pest::ReturnStatement<'ast>> for absy::StatementNode<'ast> {
|
|||
expressions: statement
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|e| absy::ExpressionNode::from(e))
|
||||
.map(absy::ExpressionNode::from)
|
||||
.collect(),
|
||||
}
|
||||
.span(statement.span.clone()),
|
||||
|
@ -275,7 +290,7 @@ impl<'ast> From<pest::IterationStatement<'ast>> for absy::StatementNode<'ast> {
|
|||
let statements: Vec<absy::StatementNode<'ast>> = statement
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| statements_from_statement(s))
|
||||
.flat_map(statements_from_statement)
|
||||
.collect();
|
||||
|
||||
let var = absy::Variable::new(index, ty).span(statement.index.span);
|
||||
|
@ -289,7 +304,7 @@ impl<'ast> From<pest::Expression<'ast>> for absy::ExpressionNode<'ast> {
|
|||
match expression {
|
||||
pest::Expression::Binary(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::Ternary(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::Constant(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::Literal(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::Identifier(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::Postfix(e) => absy::ExpressionNode::from(e),
|
||||
pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e),
|
||||
|
@ -458,7 +473,7 @@ impl<'ast> From<pest::InlineArrayExpression<'ast>> for absy::ExpressionNode<'ast
|
|||
array
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|e| absy::SpreadOrExpression::from(e))
|
||||
.map(absy::SpreadOrExpression::from)
|
||||
.collect(),
|
||||
)
|
||||
.span(array.span)
|
||||
|
@ -489,13 +504,8 @@ impl<'ast> From<pest::ArrayInitializerExpression<'ast>> for absy::ExpressionNode
|
|||
use crate::absy::NodeValue;
|
||||
|
||||
let value = absy::ExpressionNode::from(*initializer.value);
|
||||
let count: absy::ExpressionNode<'ast> = absy::ExpressionNode::from(initializer.count);
|
||||
let count = match count.value {
|
||||
absy::Expression::FieldConstant(v) => v.to_usize().unwrap(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
absy::Expression::InlineArray(vec![absy::SpreadOrExpression::Expression(value); count])
|
||||
.span(initializer.span)
|
||||
let count = absy::ExpressionNode::from(*initializer.count);
|
||||
absy::Expression::ArrayInitializer(box value, box count).span(initializer.span)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -529,9 +539,25 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
|
|||
pest::Access::Call(a) => match acc.value {
|
||||
absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
|
||||
&id_str,
|
||||
a.expressions
|
||||
a.explicit_generics.map(|explicit_generics| {
|
||||
explicit_generics
|
||||
.values
|
||||
.into_iter()
|
||||
.map(|i| match i {
|
||||
pest::ConstantGenericValue::Underscore(_) => None,
|
||||
pest::ConstantGenericValue::Value(v) => {
|
||||
Some(absy::ExpressionNode::from(v))
|
||||
}
|
||||
pest::ConstantGenericValue::Identifier(i) => {
|
||||
Some(absy::Expression::Identifier(i.span.as_str()).span(i.span))
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
a.arguments
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|e| absy::ExpressionNode::from(e))
|
||||
.map(absy::ExpressionNode::from)
|
||||
.collect(),
|
||||
),
|
||||
e => unimplemented!("only identifiers are callable, found \"{}\"", e),
|
||||
|
@ -548,29 +574,63 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::ConstantExpression<'ast>> for absy::ExpressionNode<'ast> {
|
||||
fn from(expression: pest::ConstantExpression<'ast>) -> absy::ExpressionNode<'ast> {
|
||||
impl<'ast> From<pest::DecimalLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
|
||||
fn from(expression: pest::DecimalLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
|
||||
use crate::absy::NodeValue;
|
||||
|
||||
match expression.suffix {
|
||||
Some(suffix) => match suffix {
|
||||
pest::DecimalSuffix::Field(_) => absy::Expression::FieldConstant(
|
||||
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
|
||||
),
|
||||
pest::DecimalSuffix::U32(_) => absy::Expression::U32Constant(
|
||||
u32::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
|
||||
),
|
||||
pest::DecimalSuffix::U16(_) => absy::Expression::U16Constant(
|
||||
u16::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
|
||||
),
|
||||
pest::DecimalSuffix::U8(_) => absy::Expression::U8Constant(
|
||||
u8::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
|
||||
),
|
||||
}
|
||||
.span(expression.span),
|
||||
None => absy::Expression::IntConstant(
|
||||
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
|
||||
)
|
||||
.span(expression.span),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::HexLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
|
||||
fn from(expression: pest::HexLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
|
||||
use crate::absy::NodeValue;
|
||||
|
||||
match expression.value {
|
||||
pest::HexNumberExpression::U32(e) => {
|
||||
absy::Expression::U32Constant(u32::from_str_radix(&e.span.as_str(), 16).unwrap())
|
||||
}
|
||||
pest::HexNumberExpression::U16(e) => {
|
||||
absy::Expression::U16Constant(u16::from_str_radix(&e.span.as_str(), 16).unwrap())
|
||||
}
|
||||
pest::HexNumberExpression::U8(e) => {
|
||||
absy::Expression::U8Constant(u8::from_str_radix(&e.span.as_str(), 16).unwrap())
|
||||
}
|
||||
}
|
||||
.span(expression.span)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::LiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
|
||||
fn from(expression: pest::LiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
|
||||
use crate::absy::NodeValue;
|
||||
|
||||
match expression {
|
||||
pest::ConstantExpression::BooleanLiteral(c) => {
|
||||
pest::LiteralExpression::BooleanLiteral(c) => {
|
||||
absy::Expression::BooleanConstant(c.value.parse().unwrap()).span(c.span)
|
||||
}
|
||||
pest::ConstantExpression::DecimalNumber(n) => absy::Expression::FieldConstant(
|
||||
BigUint::parse_bytes(&n.value.as_bytes(), 10).unwrap(),
|
||||
)
|
||||
.span(n.span),
|
||||
pest::ConstantExpression::U8(n) => absy::Expression::U8Constant(
|
||||
u8::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
|
||||
)
|
||||
.span(n.span),
|
||||
pest::ConstantExpression::U16(n) => absy::Expression::U16Constant(
|
||||
u16::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
|
||||
)
|
||||
.span(n.span),
|
||||
pest::ConstantExpression::U32(n) => absy::Expression::U32Constant(
|
||||
u32::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
|
||||
)
|
||||
.span(n.span),
|
||||
pest::LiteralExpression::DecimalLiteral(n) => absy::ExpressionNode::from(n),
|
||||
pest::LiteralExpression::HexLiteral(n) => absy::ExpressionNode::from(n),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -611,8 +671,8 @@ impl<'ast> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
|
||||
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode {
|
||||
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
|
||||
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode<'ast> {
|
||||
use crate::absy::types::UnresolvedType;
|
||||
use crate::absy::NodeValue;
|
||||
|
||||
|
@ -642,21 +702,7 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
|
|||
|
||||
t.dimensions
|
||||
.into_iter()
|
||||
.map(|s| match s {
|
||||
pest::Expression::Constant(c) => match c {
|
||||
pest::ConstantExpression::DecimalNumber(n) => {
|
||||
str::parse::<usize>(&n.value).unwrap()
|
||||
}
|
||||
_ => unimplemented!(
|
||||
"Array size should be a decimal number, found {}",
|
||||
c.span().as_str()
|
||||
),
|
||||
},
|
||||
e => unimplemented!(
|
||||
"Array size should be constant, found {}",
|
||||
e.span().as_str()
|
||||
),
|
||||
})
|
||||
.map(absy::ExpressionNode::from)
|
||||
.rev()
|
||||
.fold(None, |acc, s| match acc {
|
||||
None => Some(UnresolvedType::array(inner_type.clone(), s)),
|
||||
|
@ -690,10 +736,9 @@ mod tests {
|
|||
arguments: vec![],
|
||||
statements: vec![absy::Statement::Return(
|
||||
absy::ExpressionList {
|
||||
expressions: vec![absy::Expression::FieldConstant(BigUint::from(
|
||||
42u32,
|
||||
))
|
||||
.into()],
|
||||
expressions: vec![
|
||||
absy::Expression::IntConstant(42usize.into()).into()
|
||||
],
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
|
@ -771,10 +816,9 @@ mod tests {
|
|||
],
|
||||
statements: vec![absy::Statement::Return(
|
||||
absy::ExpressionList {
|
||||
expressions: vec![absy::Expression::FieldConstant(BigUint::from(
|
||||
42u32,
|
||||
))
|
||||
.into()],
|
||||
expressions: vec![
|
||||
absy::Expression::IntConstant(42usize.into()).into()
|
||||
],
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
|
@ -800,7 +844,7 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
/// Helper method to generate the ast for `def main(private {ty} a): return` which we use to check ty
|
||||
fn wrap(ty: UnresolvedType) -> absy::Module<'static> {
|
||||
fn wrap(ty: UnresolvedType<'static>) -> absy::Module<'static> {
|
||||
absy::Module {
|
||||
symbols: vec![absy::SymbolDeclaration {
|
||||
id: "main",
|
||||
|
@ -834,21 +878,31 @@ mod tests {
|
|||
("bool", UnresolvedType::Boolean),
|
||||
(
|
||||
"field[2]",
|
||||
UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 2),
|
||||
),
|
||||
(
|
||||
"field[2][3]",
|
||||
UnresolvedType::Array(
|
||||
box UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 3)
|
||||
.mock(),
|
||||
2,
|
||||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::FieldElement.mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
(
|
||||
"bool[2][3]",
|
||||
UnresolvedType::Array(
|
||||
box UnresolvedType::Array(box UnresolvedType::Boolean.mock(), 3).mock(),
|
||||
2,
|
||||
"field[2][3]",
|
||||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::FieldElement.mock(),
|
||||
absy::Expression::IntConstant(3usize.into()).mock(),
|
||||
)
|
||||
.mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
(
|
||||
"bool[2][3u32]",
|
||||
absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Array(
|
||||
box absy::UnresolvedType::Boolean.mock(),
|
||||
absy::Expression::U32Constant(3u32).mock(),
|
||||
)
|
||||
.mock(),
|
||||
absy::Expression::IntConstant(2usize.into()).mock(),
|
||||
),
|
||||
),
|
||||
];
|
||||
|
@ -893,13 +947,13 @@ mod tests {
|
|||
// we basically accept `()?[]*` : an optional call at first, then only array accesses
|
||||
|
||||
let vectors = vec![
|
||||
("a", absy::Expression::Identifier("a").into()),
|
||||
("a", absy::Expression::Identifier("a")),
|
||||
(
|
||||
"a[3]",
|
||||
absy::Expression::Select(
|
||||
box absy::Expression::Identifier("a").into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(3u32)).into(),
|
||||
absy::Expression::IntConstant(3usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -910,13 +964,13 @@ mod tests {
|
|||
box absy::Expression::Select(
|
||||
box absy::Expression::Identifier("a").into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(3u32)).into(),
|
||||
absy::Expression::IntConstant(3usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -926,11 +980,12 @@ mod tests {
|
|||
absy::Expression::Select(
|
||||
box absy::Expression::FunctionCall(
|
||||
"a",
|
||||
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()],
|
||||
None,
|
||||
vec![absy::Expression::IntConstant(3usize.into()).into()],
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -941,17 +996,18 @@ mod tests {
|
|||
box absy::Expression::Select(
|
||||
box absy::Expression::FunctionCall(
|
||||
"a",
|
||||
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()],
|
||||
None,
|
||||
vec![absy::Expression::IntConstant(3usize.into()).into()],
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
|
||||
absy::Expression::IntConstant(4usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
.into(),
|
||||
box absy::RangeOrExpression::Expression(
|
||||
absy::Expression::FieldConstant(BigUint::from(5u32)).into(),
|
||||
absy::Expression::IntConstant(5usize.into()).into(),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
|
@ -993,7 +1049,7 @@ mod tests {
|
|||
// For different definitions, we generate declarations
|
||||
// Case 1: `id = expr` where `expr` is not a function call
|
||||
// This is a simple assignment, doesn't implicitely declare a variable
|
||||
// A `Definition` is generatedm and no `Declaration`s
|
||||
// A `Definition` is generated and no `Declaration`s
|
||||
|
||||
let definition = pest::DefinitionStatement {
|
||||
lhs: vec![pest::OptionallyTypedAssignee {
|
||||
|
@ -1008,9 +1064,12 @@ mod tests {
|
|||
},
|
||||
span: span.clone(),
|
||||
}],
|
||||
expression: pest::Expression::Constant(pest::ConstantExpression::DecimalNumber(
|
||||
pest::DecimalNumberExpression {
|
||||
value: String::from("42"),
|
||||
expression: pest::Expression::Literal(pest::LiteralExpression::DecimalLiteral(
|
||||
pest::DecimalLiteralExpression {
|
||||
value: pest::DecimalNumber {
|
||||
span: Span::new(&"1", 0, 1).unwrap(),
|
||||
},
|
||||
suffix: None,
|
||||
span: span.clone(),
|
||||
},
|
||||
)),
|
||||
|
@ -1049,7 +1108,11 @@ mod tests {
|
|||
span: span.clone(),
|
||||
},
|
||||
accesses: vec![pest::Access::Call(pest::CallAccess {
|
||||
expressions: vec![],
|
||||
explicit_generics: None,
|
||||
arguments: pest::Arguments {
|
||||
expressions: vec![],
|
||||
span: span.clone(),
|
||||
},
|
||||
span: span.clone(),
|
||||
})],
|
||||
span: span.clone(),
|
||||
|
@ -1106,7 +1169,11 @@ mod tests {
|
|||
span: span.clone(),
|
||||
},
|
||||
accesses: vec![pest::Access::Call(pest::CallAccess {
|
||||
expressions: vec![],
|
||||
explicit_generics: None,
|
||||
arguments: pest::Arguments {
|
||||
expressions: vec![],
|
||||
span: span.clone(),
|
||||
},
|
||||
span: span.clone(),
|
||||
})],
|
||||
span: span.clone(),
|
||||
|
|
|
@ -105,7 +105,7 @@ impl<'ast> Module<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type UnresolvedTypeNode = Node<UnresolvedType>;
|
||||
pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
|
||||
|
||||
/// A struct type definition
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
@ -133,7 +133,7 @@ pub type StructDefinitionNode<'ast> = Node<StructDefinition<'ast>>;
|
|||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct StructDefinitionField<'ast> {
|
||||
pub id: Identifier<'ast>,
|
||||
pub ty: UnresolvedTypeNode,
|
||||
pub ty: UnresolvedTypeNode<'ast>,
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for StructDefinitionField<'ast> {
|
||||
|
@ -216,6 +216,8 @@ impl<'ast> fmt::Debug for Module<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
pub type ConstantGenericNode<'ast> = Node<Identifier<'ast>>;
|
||||
|
||||
/// A function defined locally
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct Function<'ast> {
|
||||
|
@ -224,13 +226,26 @@ pub struct Function<'ast> {
|
|||
/// Vector of statements that are executed when running the function
|
||||
pub statements: Vec<StatementNode<'ast>>,
|
||||
/// function signature
|
||||
pub signature: UnresolvedSignature,
|
||||
pub signature: UnresolvedSignature<'ast>,
|
||||
}
|
||||
|
||||
pub type FunctionNode<'ast> = Node<Function<'ast>>;
|
||||
|
||||
impl<'ast> fmt::Display for Function<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if !self.signature.generics.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"<{}>",
|
||||
self.signature
|
||||
.generics
|
||||
.iter()
|
||||
.map(|g| g.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
"({}):\n{}",
|
||||
|
@ -294,6 +309,7 @@ impl<'ast> fmt::Display for Assignee<'ast> {
|
|||
}
|
||||
|
||||
/// A statement in a `Function`
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum Statement<'ast> {
|
||||
Return(ExpressionListNode<'ast>),
|
||||
|
@ -319,9 +335,9 @@ impl<'ast> fmt::Display for Statement<'ast> {
|
|||
Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
|
||||
Statement::Assertion(ref e) => write!(f, "assert({})", e),
|
||||
Statement::For(ref var, ref start, ref stop, ref list) => {
|
||||
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
|
||||
writeln!(f, "for {} in {}..{} do", var, start, stop)?;
|
||||
for l in list {
|
||||
write!(f, "\t\t{}\n", l)?;
|
||||
writeln!(f, "\t\t{}", l)?;
|
||||
}
|
||||
write!(f, "\tendfor")
|
||||
}
|
||||
|
@ -348,9 +364,9 @@ impl<'ast> fmt::Debug for Statement<'ast> {
|
|||
}
|
||||
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
|
||||
Statement::For(ref var, ref start, ref stop, ref list) => {
|
||||
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
|
||||
writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?;
|
||||
for l in list {
|
||||
write!(f, "\t\t{:?}\n", l)?;
|
||||
writeln!(f, "\t\t{:?}", l)?;
|
||||
}
|
||||
write!(f, "\tendfor")
|
||||
}
|
||||
|
@ -454,11 +470,11 @@ impl<'ast> fmt::Display for Range<'ast> {
|
|||
self.from
|
||||
.as_ref()
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or("".to_string()),
|
||||
.unwrap_or_else(|| "".to_string()),
|
||||
self.to
|
||||
.as_ref()
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or("".to_string())
|
||||
.unwrap_or_else(|| "".to_string())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -472,6 +488,7 @@ impl<'ast> fmt::Debug for Range<'ast> {
|
|||
/// An expression
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub enum Expression<'ast> {
|
||||
IntConstant(BigUint),
|
||||
FieldConstant(BigUint),
|
||||
BooleanConstant(bool),
|
||||
U8Constant(u8),
|
||||
|
@ -491,7 +508,11 @@ pub enum Expression<'ast> {
|
|||
Box<ExpressionNode<'ast>>,
|
||||
Box<ExpressionNode<'ast>>,
|
||||
),
|
||||
FunctionCall(FunctionIdentifier<'ast>, Vec<ExpressionNode<'ast>>),
|
||||
FunctionCall(
|
||||
FunctionIdentifier<'ast>,
|
||||
Option<Vec<Option<ExpressionNode<'ast>>>>,
|
||||
Vec<ExpressionNode<'ast>>,
|
||||
),
|
||||
Lt(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Le(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Eq(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
|
@ -500,6 +521,7 @@ pub enum Expression<'ast> {
|
|||
And(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
Not(Box<ExpressionNode<'ast>>),
|
||||
InlineArray(Vec<SpreadOrExpression<'ast>>),
|
||||
ArrayInitializer(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
|
||||
InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast>)>),
|
||||
Select(Box<ExpressionNode<'ast>>, Box<RangeOrExpression<'ast>>),
|
||||
Member(Box<ExpressionNode<'ast>>, Box<Identifier<'ast>>),
|
||||
|
@ -520,6 +542,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
|
|||
Expression::U8Constant(ref i) => write!(f, "{}", i),
|
||||
Expression::U16Constant(ref i) => write!(f, "{}", i),
|
||||
Expression::U32Constant(ref i) => write!(f, "{}", i),
|
||||
Expression::IntConstant(ref i) => write!(f, "{}", i),
|
||||
Expression::Identifier(ref var) => write!(f, "{}", var),
|
||||
Expression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
|
||||
Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
|
||||
|
@ -535,8 +558,21 @@ impl<'ast> fmt::Display for Expression<'ast> {
|
|||
"if {} then {} else {} fi",
|
||||
condition, consequent, alternative
|
||||
),
|
||||
Expression::FunctionCall(ref i, ref p) => {
|
||||
write!(f, "{}(", i,)?;
|
||||
Expression::FunctionCall(ref i, ref g, ref p) => {
|
||||
if let Some(g) = g {
|
||||
write!(
|
||||
f,
|
||||
"::<{}>",
|
||||
g.iter()
|
||||
.map(|g| g
|
||||
.as_ref()
|
||||
.map(|g| g.to_string())
|
||||
.unwrap_or_else(|| "_".into()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
)?;
|
||||
}
|
||||
write!(f, "{}(", i)?;
|
||||
for (i, param) in p.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < p.len() - 1 {
|
||||
|
@ -562,6 +598,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
|
|||
}
|
||||
write!(f, "]")
|
||||
}
|
||||
Expression::ArrayInitializer(ref e, ref count) => write!(f, "[{}; {}]", e, count),
|
||||
Expression::InlineStruct(ref id, ref members) => {
|
||||
write!(f, "{} {{", id)?;
|
||||
for (i, (member_id, e)) in members.iter().enumerate() {
|
||||
|
@ -587,10 +624,11 @@ impl<'ast> fmt::Display for Expression<'ast> {
|
|||
impl<'ast> fmt::Debug for Expression<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
Expression::U8Constant(ref i) => write!(f, "{:x}", i),
|
||||
Expression::U16Constant(ref i) => write!(f, "{:x}", i),
|
||||
Expression::U32Constant(ref i) => write!(f, "{:x}", i),
|
||||
Expression::FieldConstant(ref i) => write!(f, "Num({:?})", i),
|
||||
Expression::U8Constant(ref i) => write!(f, "U8({:x})", i),
|
||||
Expression::U16Constant(ref i) => write!(f, "U16({:x})", i),
|
||||
Expression::U32Constant(ref i) => write!(f, "U32({:x})", i),
|
||||
Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i),
|
||||
Expression::IntConstant(ref i) => write!(f, "Int({:?})", i),
|
||||
Expression::Identifier(ref var) => write!(f, "Ide({})", var),
|
||||
Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
|
||||
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
|
||||
|
@ -606,8 +644,8 @@ impl<'ast> fmt::Debug for Expression<'ast> {
|
|||
"IfElse({:?}, {:?}, {:?})",
|
||||
condition, consequent, alternative
|
||||
),
|
||||
Expression::FunctionCall(ref i, ref p) => {
|
||||
write!(f, "FunctionCall({:?}, (", i)?;
|
||||
Expression::FunctionCall(ref g, ref i, ref p) => {
|
||||
write!(f, "FunctionCall({:?}, {:?}, (", g, i)?;
|
||||
f.debug_list().entries(p.iter()).finish()?;
|
||||
write!(f, ")")
|
||||
}
|
||||
|
@ -623,6 +661,9 @@ impl<'ast> fmt::Debug for Expression<'ast> {
|
|||
f.debug_list().entries(exprs.iter()).finish()?;
|
||||
write!(f, "]")
|
||||
}
|
||||
Expression::ArrayInitializer(ref e, ref count) => {
|
||||
write!(f, "ArrayInitializer({:?}, {:?})", e, count)
|
||||
}
|
||||
Expression::InlineStruct(ref id, ref members) => {
|
||||
write!(f, "InlineStruct({:?}, [", id)?;
|
||||
f.debug_list().entries(members.iter()).finish()?;
|
||||
|
@ -645,21 +686,13 @@ impl<'ast> fmt::Debug for Expression<'ast> {
|
|||
}
|
||||
|
||||
/// A list of expressions, used in return statements
|
||||
#[derive(Clone, PartialEq)]
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct ExpressionList<'ast> {
|
||||
pub expressions: Vec<ExpressionNode<'ast>>,
|
||||
}
|
||||
|
||||
pub type ExpressionListNode<'ast> = Node<ExpressionList<'ast>>;
|
||||
|
||||
impl<'ast> ExpressionList<'ast> {
|
||||
pub fn new() -> ExpressionList<'ast> {
|
||||
ExpressionList {
|
||||
expressions: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast> fmt::Display for ExpressionList<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
for (i, param) in self.expressions.iter().enumerate() {
|
||||
|
|
|
@ -81,7 +81,7 @@ impl<'ast> NodeValue for ExpressionList<'ast> {}
|
|||
impl<'ast> NodeValue for Assignee<'ast> {}
|
||||
impl<'ast> NodeValue for Statement<'ast> {}
|
||||
impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
|
||||
impl NodeValue for UnresolvedType {}
|
||||
impl<'ast> NodeValue for UnresolvedType<'ast> {}
|
||||
impl<'ast> NodeValue for StructDefinition<'ast> {}
|
||||
impl<'ast> NodeValue for StructDefinitionField<'ast> {}
|
||||
impl<'ast> NodeValue for Function<'ast> {}
|
||||
|
@ -92,6 +92,7 @@ impl<'ast> NodeValue for Parameter<'ast> {}
|
|||
impl<'ast> NodeValue for Import<'ast> {}
|
||||
impl<'ast> NodeValue for Spread<'ast> {}
|
||||
impl<'ast> NodeValue for Range<'ast> {}
|
||||
impl<'ast> NodeValue for Identifier<'ast> {}
|
||||
|
||||
impl<T: PartialEq> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Node<T>) -> bool {
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::absy::ExpressionNode;
|
||||
use crate::absy::UnresolvedTypeNode;
|
||||
use std::fmt;
|
||||
|
||||
|
@ -8,15 +9,15 @@ pub type MemberId = String;
|
|||
pub type UserTypeId = String;
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub enum UnresolvedType {
|
||||
pub enum UnresolvedType<'ast> {
|
||||
FieldElement,
|
||||
Boolean,
|
||||
Uint(usize),
|
||||
Array(Box<UnresolvedTypeNode>, usize),
|
||||
Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>),
|
||||
User(UserTypeId),
|
||||
}
|
||||
|
||||
impl fmt::Display for UnresolvedType {
|
||||
impl<'ast> fmt::Display for UnresolvedType<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
UnresolvedType::FieldElement => write!(f, "field"),
|
||||
|
@ -28,8 +29,8 @@ impl fmt::Display for UnresolvedType {
|
|||
}
|
||||
}
|
||||
|
||||
impl UnresolvedType {
|
||||
pub fn array(ty: UnresolvedTypeNode, size: usize) -> Self {
|
||||
impl<'ast> UnresolvedType<'ast> {
|
||||
pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self {
|
||||
UnresolvedType::Array(box ty, size)
|
||||
}
|
||||
}
|
||||
|
@ -39,17 +40,19 @@ pub type FunctionIdentifier<'ast> = &'ast str;
|
|||
pub use self::signature::UnresolvedSignature;
|
||||
|
||||
mod signature {
|
||||
use crate::absy::ConstantGenericNode;
|
||||
use std::fmt;
|
||||
|
||||
use crate::absy::UnresolvedTypeNode;
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
pub struct UnresolvedSignature {
|
||||
pub inputs: Vec<UnresolvedTypeNode>,
|
||||
pub outputs: Vec<UnresolvedTypeNode>,
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct UnresolvedSignature<'ast> {
|
||||
pub generics: Vec<ConstantGenericNode<'ast>>,
|
||||
pub inputs: Vec<UnresolvedTypeNode<'ast>>,
|
||||
pub outputs: Vec<UnresolvedTypeNode<'ast>>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for UnresolvedSignature {
|
||||
impl<'ast> fmt::Debug for UnresolvedSignature<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
|
@ -59,7 +62,7 @@ mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for UnresolvedSignature {
|
||||
impl<'ast> fmt::Display for UnresolvedSignature<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "(")?;
|
||||
for (i, t) in self.inputs.iter().enumerate() {
|
||||
|
@ -79,20 +82,22 @@ mod signature {
|
|||
}
|
||||
}
|
||||
|
||||
impl UnresolvedSignature {
|
||||
pub fn new() -> UnresolvedSignature {
|
||||
UnresolvedSignature {
|
||||
inputs: vec![],
|
||||
outputs: vec![],
|
||||
}
|
||||
impl<'ast> UnresolvedSignature<'ast> {
|
||||
pub fn new() -> UnresolvedSignature<'ast> {
|
||||
UnresolvedSignature::default()
|
||||
}
|
||||
|
||||
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode>) -> Self {
|
||||
pub fn generics(mut self, generics: Vec<ConstantGenericNode<'ast>>) -> Self {
|
||||
self.generics = generics;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
|
||||
self.inputs = inputs;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode>) -> Self {
|
||||
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
|
||||
self.outputs = outputs;
|
||||
self
|
||||
}
|
||||
|
|
|
@ -7,21 +7,21 @@ use crate::absy::Identifier;
|
|||
#[derive(Clone, PartialEq)]
|
||||
pub struct Variable<'ast> {
|
||||
pub id: Identifier<'ast>,
|
||||
pub _type: UnresolvedTypeNode,
|
||||
pub _type: UnresolvedTypeNode<'ast>,
|
||||
}
|
||||
|
||||
pub type VariableNode<'ast> = Node<Variable<'ast>>;
|
||||
|
||||
impl<'ast> Variable<'ast> {
|
||||
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode) -> Variable<'ast> {
|
||||
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode<'ast>) -> Variable<'ast> {
|
||||
Variable {
|
||||
id: id.into(),
|
||||
_type: t,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_type(&self) -> UnresolvedType {
|
||||
self._type.value.clone()
|
||||
pub fn get_type(&self) -> &UnresolvedType<'ast> {
|
||||
&self._type.value
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::imports::{self, Importer};
|
|||
use crate::ir;
|
||||
use crate::macros;
|
||||
use crate::semantics::{self, Checker};
|
||||
use crate::static_analysis;
|
||||
use crate::static_analysis::Analyse;
|
||||
use crate::typed_absy::abi::Abi;
|
||||
use crate::zir::ZirProgram;
|
||||
|
@ -55,6 +56,7 @@ pub enum CompileErrorInner {
|
|||
MacroError(macros::Error),
|
||||
SemanticError(semantics::ErrorInner),
|
||||
ReadError(io::Error),
|
||||
AnalysisError(static_analysis::Error),
|
||||
}
|
||||
|
||||
impl CompileErrorInner {
|
||||
|
@ -129,6 +131,12 @@ impl From<semantics::Error> for CompileError {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<static_analysis::Error> for CompileErrorInner {
|
||||
fn from(error: static_analysis::Error) -> Self {
|
||||
CompileErrorInner::AnalysisError(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CompileErrorInner {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match *self {
|
||||
|
@ -137,6 +145,7 @@ impl fmt::Display for CompileErrorInner {
|
|||
CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e),
|
||||
CompileErrorInner::ReadError(ref e) => write!(f, "{}", e),
|
||||
CompileErrorInner::ImportError(ref e) => write!(f, "{}", e),
|
||||
CompileErrorInner::AnalysisError(ref e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -179,7 +188,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
|
|||
})
|
||||
}
|
||||
|
||||
pub fn check<'ast, T: Field, E: Into<imports::Error>>(
|
||||
pub fn check<T: Field, E: Into<imports::Error>>(
|
||||
source: String,
|
||||
location: FilePath,
|
||||
resolver: Option<&dyn Resolver<E>>,
|
||||
|
@ -196,19 +205,18 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
|
|||
arena: &'ast Arena<String>,
|
||||
) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> {
|
||||
let source = arena.alloc(source);
|
||||
let compiled = compile_program::<T, E>(source, location.clone(), resolver, &arena)?;
|
||||
let compiled = compile_program::<T, E>(source, location, resolver, &arena)?;
|
||||
|
||||
// check semantics
|
||||
let typed_ast = Checker::check(compiled).map_err(|errors| {
|
||||
CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect())
|
||||
})?;
|
||||
let typed_ast = Checker::check(compiled)
|
||||
.map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?;
|
||||
|
||||
let abi = typed_ast.abi();
|
||||
let main_module = typed_ast.main.clone();
|
||||
|
||||
// analyse (unroll and constant propagation)
|
||||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
Ok((typed_ast, abi))
|
||||
typed_ast
|
||||
.analyse()
|
||||
.map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)]))
|
||||
}
|
||||
|
||||
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
|
||||
|
@ -244,7 +252,7 @@ pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>(
|
|||
|
||||
let module_without_imports: Module = Module::from(ast);
|
||||
|
||||
Importer::new().apply_imports::<T, E>(
|
||||
Importer::apply_imports::<T, E>(
|
||||
module_without_imports,
|
||||
location.clone(),
|
||||
resolver,
|
||||
|
@ -380,17 +388,17 @@ struct Bar { field a }
|
|||
inputs: vec![AbiInput {
|
||||
name: "f".into(),
|
||||
public: true,
|
||||
ty: Type::Struct(StructType::new(
|
||||
ty: ConcreteType::Struct(ConcreteStructType::new(
|
||||
"foo".into(),
|
||||
"Foo".into(),
|
||||
vec![StructMember {
|
||||
vec![ConcreteStructMember {
|
||||
id: "b".into(),
|
||||
ty: box Type::Struct(StructType::new(
|
||||
ty: box ConcreteType::Struct(ConcreteStructType::new(
|
||||
"bar".into(),
|
||||
"Bar".into(),
|
||||
vec![StructMember {
|
||||
vec![ConcreteStructMember {
|
||||
id: "a".into(),
|
||||
ty: box Type::FieldElement
|
||||
ty: box ConcreteType::FieldElement
|
||||
}]
|
||||
))
|
||||
}]
|
||||
|
|
|
@ -3,7 +3,9 @@ use crate::flat_absy::{
|
|||
FlatVariable,
|
||||
};
|
||||
use crate::solvers::Solver;
|
||||
use crate::typed_absy::types::{FunctionKey, Signature, Type};
|
||||
use crate::typed_absy::types::{
|
||||
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use zokrates_field::{Bn128Field, Field};
|
||||
|
||||
|
@ -18,11 +20,12 @@ cfg_if::cfg_if! {
|
|||
|
||||
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
|
||||
/// the flattening step when it can be inlined.
|
||||
#[derive(Debug, Clone, PartialEq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
|
||||
pub enum FlatEmbed {
|
||||
U32ToField,
|
||||
#[cfg(feature = "bellman")]
|
||||
Sha256Round,
|
||||
Unpack(usize),
|
||||
Unpack,
|
||||
U8ToBits,
|
||||
U16ToBits,
|
||||
U32ToBits,
|
||||
|
@ -32,48 +35,88 @@ pub enum FlatEmbed {
|
|||
}
|
||||
|
||||
impl FlatEmbed {
|
||||
pub fn signature(&self) -> Signature {
|
||||
pub fn signature(&self) -> DeclarationSignature<'static> {
|
||||
match self {
|
||||
FlatEmbed::U32ToField => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::FieldElement]),
|
||||
FlatEmbed::Unpack => DeclarationSignature::new()
|
||||
.generics(vec![Some(Constant::Generic("N"))])
|
||||
.inputs(vec![DeclarationType::FieldElement])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
"N",
|
||||
))]),
|
||||
FlatEmbed::U8ToBits => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(8)])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
8usize,
|
||||
))]),
|
||||
FlatEmbed::U16ToBits => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(16)])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
16usize,
|
||||
))]),
|
||||
FlatEmbed::U32ToBits => DeclarationSignature::new()
|
||||
.inputs(vec![DeclarationType::uint(32)])
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
32usize,
|
||||
))]),
|
||||
FlatEmbed::U8FromBits => DeclarationSignature::new()
|
||||
.outputs(vec![DeclarationType::uint(8)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
8usize,
|
||||
))]),
|
||||
FlatEmbed::U16FromBits => DeclarationSignature::new()
|
||||
.outputs(vec![DeclarationType::uint(16)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
16usize,
|
||||
))]),
|
||||
FlatEmbed::U32FromBits => DeclarationSignature::new()
|
||||
.outputs(vec![DeclarationType::uint(32)])
|
||||
.inputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
32usize,
|
||||
))]),
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => Signature::new()
|
||||
FlatEmbed::Sha256Round => DeclarationSignature::new()
|
||||
.inputs(vec![
|
||||
Type::array(Type::Boolean, 512),
|
||||
Type::array(Type::Boolean, 256),
|
||||
DeclarationType::array((DeclarationType::Boolean, 512usize)),
|
||||
DeclarationType::array((DeclarationType::Boolean, 256usize)),
|
||||
])
|
||||
.outputs(vec![Type::array(Type::Boolean, 256)]),
|
||||
FlatEmbed::Unpack(bitwidth) => Signature::new()
|
||||
.inputs(vec![Type::FieldElement])
|
||||
.outputs(vec![Type::array(Type::Boolean, *bitwidth)]),
|
||||
FlatEmbed::U8ToBits => Signature::new()
|
||||
.inputs(vec![Type::uint(8)])
|
||||
.outputs(vec![Type::array(Type::Boolean, 8)]),
|
||||
FlatEmbed::U16ToBits => Signature::new()
|
||||
.inputs(vec![Type::uint(16)])
|
||||
.outputs(vec![Type::array(Type::Boolean, 16)]),
|
||||
FlatEmbed::U32ToBits => Signature::new()
|
||||
.inputs(vec![Type::uint(32)])
|
||||
.outputs(vec![Type::array(Type::Boolean, 32)]),
|
||||
FlatEmbed::U8FromBits => Signature::new()
|
||||
.outputs(vec![Type::uint(8)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 8)]),
|
||||
FlatEmbed::U16FromBits => Signature::new()
|
||||
.outputs(vec![Type::uint(16)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 16)]),
|
||||
FlatEmbed::U32FromBits => Signature::new()
|
||||
.outputs(vec![Type::uint(32)])
|
||||
.inputs(vec![Type::array(Type::Boolean, 32)]),
|
||||
.outputs(vec![DeclarationType::array((
|
||||
DeclarationType::Boolean,
|
||||
256usize,
|
||||
))]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn key<T: Field>(&self) -> FunctionKey<'static> {
|
||||
FunctionKey::with_id(self.id()).signature(self.signature())
|
||||
pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
|
||||
let gen = self
|
||||
.signature()
|
||||
.generics
|
||||
.into_iter()
|
||||
.map(|c| match c.unwrap() {
|
||||
Constant::Generic(g) => g,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
assert_eq!(gen.len(), assignment.0.len());
|
||||
gen.map(|g| *assignment.0.get(&g).clone().unwrap() as u32)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
FlatEmbed::U32ToField => "_U32_TO_FIELD",
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => "_SHA256_ROUND",
|
||||
FlatEmbed::Unpack(_) => "_UNPACK",
|
||||
FlatEmbed::Unpack => "_UNPACK",
|
||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||
FlatEmbed::U16ToBits => "_U16_TO_BITS",
|
||||
FlatEmbed::U32ToBits => "_U32_TO_BITS",
|
||||
|
@ -84,11 +127,11 @@ impl FlatEmbed {
|
|||
}
|
||||
|
||||
/// Actually get the `FlatFunction` that this `FlatEmbed` represents
|
||||
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> {
|
||||
pub fn synthetize<T: Field>(&self, generics: &[u32]) -> FlatFunction<T> {
|
||||
match self {
|
||||
#[cfg(feature = "bellman")]
|
||||
FlatEmbed::Sha256Round => sha256_round(),
|
||||
FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth),
|
||||
FlatEmbed::Unpack => unpack_to_bitwidth(generics[0] as usize),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -101,7 +144,7 @@ fn flat_expression_from_vec<T: Field, E: Engine>(v: &[(usize, E::Fr)]) -> FlatEx
|
|||
match v.len() {
|
||||
0 => FlatExpression::Number(T::zero()),
|
||||
1 => {
|
||||
let (key, val) = v[0].clone();
|
||||
let (key, val) = v[0];
|
||||
let mut res: Vec<u8> = vec![];
|
||||
val.into_repr().write_le(&mut res).unwrap();
|
||||
FlatExpression::Mult(
|
||||
|
@ -152,7 +195,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
|
|||
let output_indices = output_indices.into_iter();
|
||||
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
|
||||
// indices of the sha256round constraint system variables
|
||||
let cs_indices = (0..variable_count).into_iter();
|
||||
let cs_indices = 0..variable_count;
|
||||
// indices of the arguments to the function
|
||||
// apply an offset of `variable_count` to get the indice of our dummy `input` argument
|
||||
let input_argument_indices = input_indices
|
||||
|
@ -180,7 +223,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
|
|||
);
|
||||
let input_binding_statements =
|
||||
// bind input and current_hash to inputs
|
||||
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
|
||||
input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
|
||||
FlatStatement::Condition(
|
||||
FlatVariable::new(cs_index).into(),
|
||||
FlatVariable::new(argument_index).into(),
|
||||
|
@ -194,7 +237,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
|
|||
.collect();
|
||||
// insert a directive to set the witness based on the bellman gadget and inputs
|
||||
let directive_statement = FlatStatement::Directive(FlatDirective {
|
||||
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
|
||||
outputs: cs_indices.map(FlatVariable::new).collect(),
|
||||
inputs: input_argument_indices
|
||||
.chain(current_hash_argument_indices)
|
||||
.map(|i| FlatVariable::new(i).into())
|
||||
|
@ -224,7 +267,7 @@ fn use_variable(
|
|||
) -> FlatVariable {
|
||||
let var = FlatVariable::new(*index);
|
||||
layout.insert(name, var);
|
||||
*index = *index + 1;
|
||||
*index += 1;
|
||||
var
|
||||
}
|
||||
|
||||
|
@ -255,7 +298,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
|
|||
|
||||
let directive_inputs = vec![FlatExpression::Identifier(use_variable(
|
||||
&mut layout,
|
||||
format!("i0"),
|
||||
"i0".into(),
|
||||
&mut counter,
|
||||
))];
|
||||
|
||||
|
@ -268,7 +311,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
|
|||
let outputs = directive_outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
|
||||
.map(|(_, o)| FlatExpression::Identifier(*o))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// o253, o252, ... o{253 - (bit_width - 1)} are bits
|
||||
|
@ -308,7 +351,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
|
|||
FlatStatement::Directive(FlatDirective {
|
||||
inputs: directive_inputs,
|
||||
outputs: directive_outputs,
|
||||
solver: solver,
|
||||
solver,
|
||||
}),
|
||||
);
|
||||
|
||||
|
@ -436,7 +479,7 @@ mod tests {
|
|||
private: vec![true; 768],
|
||||
};
|
||||
|
||||
let input = (0..512)
|
||||
let input: Vec<_> = (0..512)
|
||||
.map(|_| 0)
|
||||
.chain((0..256).map(|_| 1))
|
||||
.map(|i| Bn128Field::from(i))
|
||||
|
|
|
@ -41,7 +41,7 @@ impl FlatParameter {
|
|||
substitution: &HashMap<FlatVariable, FlatVariable>,
|
||||
) -> FlatParameter {
|
||||
FlatParameter {
|
||||
id: substitution.get(&self.id).unwrap().clone(),
|
||||
id: *substitution.get(&self.id).unwrap(),
|
||||
private: self.private,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ impl FlatVariable {
|
|||
Ok(FlatVariable::public(v))
|
||||
}
|
||||
None => {
|
||||
let mut private = s.split("_");
|
||||
let mut private = s.split('_');
|
||||
match private.nth(1) {
|
||||
Some(v) => {
|
||||
let v = v.parse().map_err(|_| s)?;
|
||||
|
|
|
@ -155,23 +155,13 @@ impl<T: Field> FlatStatement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Hash, Debug)]
|
||||
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
|
||||
pub struct FlatDirective<T: Field> {
|
||||
pub inputs: Vec<FlatExpression<T>>,
|
||||
pub outputs: Vec<FlatVariable>,
|
||||
pub solver: Solver,
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for FlatDirective<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.inputs.eq(&other.inputs)
|
||||
&& self.outputs.eq(&other.outputs)
|
||||
&& self.solver.eq(&other.solver)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Eq for FlatDirective<T> {}
|
||||
|
||||
impl<T: Field> FlatDirective<T> {
|
||||
pub fn new<E: Into<FlatExpression<T>>>(
|
||||
outputs: Vec<FlatVariable>,
|
||||
|
@ -249,12 +239,18 @@ impl<T: Field> FlatExpression<T> {
|
|||
FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => {
|
||||
x.is_linear() && y.is_linear()
|
||||
}
|
||||
FlatExpression::Mult(ref x, ref y) => match (x.clone(), y.clone()) {
|
||||
FlatExpression::Mult(ref x, ref y) => matches!(
|
||||
(x.clone(), y.clone()),
|
||||
(box FlatExpression::Number(_), box FlatExpression::Number(_))
|
||||
| (box FlatExpression::Number(_), box FlatExpression::Identifier(_))
|
||||
| (box FlatExpression::Identifier(_), box FlatExpression::Number(_)) => true,
|
||||
_ => false,
|
||||
},
|
||||
| (
|
||||
box FlatExpression::Number(_),
|
||||
box FlatExpression::Identifier(_)
|
||||
)
|
||||
| (
|
||||
box FlatExpression::Identifier(_),
|
||||
box FlatExpression::Number(_)
|
||||
)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -42,7 +42,7 @@ impl fmt::Display for Error {
|
|||
let location = self
|
||||
.pos
|
||||
.map(|p| format!("{}", p.0))
|
||||
.unwrap_or("?".to_string());
|
||||
.unwrap_or_else(|| "?".to_string());
|
||||
write!(f, "{}\n\t{}", location, self.message)
|
||||
}
|
||||
}
|
||||
|
@ -125,15 +125,10 @@ impl<'ast> fmt::Debug for Import<'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct Importer {}
|
||||
pub struct Importer;
|
||||
|
||||
impl Importer {
|
||||
pub fn new() -> Importer {
|
||||
Importer {}
|
||||
}
|
||||
|
||||
pub fn apply_imports<'ast, T: Field, E: Into<Error>>(
|
||||
&self,
|
||||
destination: Module<'ast>,
|
||||
location: PathBuf,
|
||||
resolver: Option<&dyn Resolver<E>>,
|
||||
|
@ -179,7 +174,7 @@ impl Importer {
|
|||
symbols.push(
|
||||
SymbolDeclaration {
|
||||
id: &alias,
|
||||
symbol: Symbol::Flat(FlatEmbed::Unpack(T::get_required_bits())),
|
||||
symbol: Symbol::Flat(FlatEmbed::Unpack),
|
||||
}
|
||||
.start_end(pos.0, pos.1),
|
||||
);
|
||||
|
@ -267,13 +262,15 @@ impl Importer {
|
|||
let alias = import.alias.unwrap_or(
|
||||
std::path::Path::new(import.source)
|
||||
.file_stem()
|
||||
.ok_or(CompileErrors::from(
|
||||
CompileErrorInner::ImportError(Error::new(format!(
|
||||
"Could not determine alias for import {}",
|
||||
import.source.display()
|
||||
)))
|
||||
.in_file(&location),
|
||||
))?
|
||||
.ok_or_else(|| {
|
||||
CompileErrors::from(
|
||||
CompileErrorInner::ImportError(Error::new(format!(
|
||||
"Could not determine alias for import {}",
|
||||
import.source.display()
|
||||
)))
|
||||
.in_file(&location),
|
||||
)
|
||||
})?
|
||||
.to_str()
|
||||
.unwrap(),
|
||||
);
|
||||
|
@ -335,7 +332,6 @@ impl Importer {
|
|||
Ok(Module {
|
||||
imports: vec![],
|
||||
symbols,
|
||||
..destination
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,49 +2,36 @@ use crate::flat_absy::FlatVariable;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::btree_map::{BTreeMap, Entry};
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use std::ops::{Add, Div, Mul, Sub};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub struct QuadComb<T> {
|
||||
pub left: LinComb<T>,
|
||||
pub right: LinComb<T>,
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for QuadComb<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.left.eq(&other.left) && self.right.eq(&other.right)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Eq for QuadComb<T> {}
|
||||
|
||||
impl<T: Field> QuadComb<T> {
|
||||
pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self {
|
||||
QuadComb { left, right }
|
||||
}
|
||||
|
||||
pub fn try_linear(&self) -> Option<LinComb<T>> {
|
||||
// identify (k * ~ONE) * (lincomb) and return (k * lincomb)
|
||||
|
||||
match self.left.try_summand() {
|
||||
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
|
||||
return Some(self.right.clone() * &coefficient);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
match self.right.try_summand() {
|
||||
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
|
||||
return Some(self.left.clone() * &coefficient);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
pub fn try_linear(self) -> Result<LinComb<T>, Self> {
|
||||
// identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb)
|
||||
// if not, error out with the input
|
||||
|
||||
if self.left.is_zero() || self.right.is_zero() {
|
||||
return Some(LinComb::zero());
|
||||
return Ok(LinComb::zero());
|
||||
}
|
||||
|
||||
None
|
||||
match self.left.try_constant() {
|
||||
Ok(coefficient) => Ok(self.right * &coefficient),
|
||||
Err(left) => match self.right.try_constant() {
|
||||
Ok(coefficient) => Ok(left * &coefficient),
|
||||
Err(right) => Err(QuadComb::from_linear_combinations(left, right)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,17 +53,9 @@ impl<T: Field> fmt::Display for QuadComb<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Hash, Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub struct LinComb<T>(pub Vec<(FlatVariable, T)>);
|
||||
|
||||
impl<T: Field> PartialEq for LinComb<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.clone().into_canonical() == other.clone().into_canonical()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Eq for LinComb<T> {}
|
||||
|
||||
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
|
||||
pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>);
|
||||
|
||||
|
@ -113,36 +92,52 @@ impl<T> LinComb<T> {
|
|||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.0.len() == 0
|
||||
self.0.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> LinComb<T> {
|
||||
pub fn try_summand(&self) -> Option<(FlatVariable, T)> {
|
||||
pub fn try_constant(self) -> Result<T, Self> {
|
||||
match self.0.len() {
|
||||
// if the lincomb is empty, it is not reduceable to a summand
|
||||
0 => None,
|
||||
// if the lincomb is empty, it is reduceable to 0
|
||||
0 => Ok(T::zero()),
|
||||
_ => {
|
||||
// take the first variable in the lincomb
|
||||
let first = &self.0[0].0;
|
||||
|
||||
self.0
|
||||
.iter()
|
||||
.map(|element| {
|
||||
if first != &FlatVariable::one() {
|
||||
return Err(self);
|
||||
}
|
||||
|
||||
// all terms must contain the same variable
|
||||
if self.0.iter().all(|element| element.0 == *first) {
|
||||
Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1))
|
||||
} else {
|
||||
Err(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_summand(self) -> Result<(FlatVariable, T), Self> {
|
||||
match self.0.len() {
|
||||
// if the lincomb is empty, it is not reduceable to a summand
|
||||
0 => Err(self),
|
||||
_ => {
|
||||
// take the first variable in the lincomb
|
||||
let first = &self.0[0].0;
|
||||
|
||||
if self.0.iter().all(|element|
|
||||
// all terms must contain the same variable
|
||||
if element.0 == *first {
|
||||
// if they do, return the coefficient
|
||||
Ok(&element.1)
|
||||
} else {
|
||||
// otherwise, stop
|
||||
Err(())
|
||||
}
|
||||
})
|
||||
// collect to a Result to short circuit when we hit an error
|
||||
.collect::<Result<_, _>>()
|
||||
// we didn't hit an error, do final processing. It's fine to clone here.
|
||||
.map(|v: Vec<_>| (first.clone(), v.iter().fold(T::zero(), |acc, e| acc + *e)))
|
||||
.ok()
|
||||
element.0 == *first)
|
||||
{
|
||||
Ok((
|
||||
*first,
|
||||
self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1),
|
||||
))
|
||||
} else {
|
||||
Err(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -207,9 +202,7 @@ impl<T: Field> fmt::Display for LinComb<T> {
|
|||
false => write!(
|
||||
f,
|
||||
"{}",
|
||||
self.clone()
|
||||
.into_canonical()
|
||||
.0
|
||||
self.0
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k))
|
||||
.collect::<Vec<_>>()
|
||||
|
@ -251,10 +244,14 @@ impl<T: Field> Mul<&T> for LinComb<T> {
|
|||
type Output = LinComb<T>;
|
||||
|
||||
fn mul(self, scalar: &T) -> LinComb<T> {
|
||||
if scalar == &T::one() {
|
||||
return self;
|
||||
}
|
||||
|
||||
LinComb(
|
||||
self.0
|
||||
.into_iter()
|
||||
.map(|(var, coeff)| (var, coeff * scalar))
|
||||
.map(|(var, coeff)| (var, coeff * scalar.clone()))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
@ -262,7 +259,8 @@ impl<T: Field> Mul<&T> for LinComb<T> {
|
|||
|
||||
impl<T: Field> Div<&T> for LinComb<T> {
|
||||
type Output = LinComb<T>;
|
||||
|
||||
// Clippy warns about multiplication in a method named div. It's okay, here, since we multiply with the inverse.
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, scalar: &T) -> LinComb<T> {
|
||||
self * &scalar.inverse_mul().unwrap()
|
||||
}
|
||||
|
@ -287,7 +285,7 @@ mod tests {
|
|||
fn add() {
|
||||
let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
|
||||
let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
|
||||
let c = a + b.clone();
|
||||
let c = a + b;
|
||||
|
||||
let expected_vec = vec![
|
||||
(FlatVariable::new(42), Bn128Field::from(1)),
|
||||
|
@ -300,7 +298,7 @@ mod tests {
|
|||
fn sub() {
|
||||
let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
|
||||
let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
|
||||
let c = a - b.clone();
|
||||
let c = a - b;
|
||||
|
||||
let expected_vec = vec![
|
||||
(FlatVariable::new(42), Bn128Field::from(1)),
|
||||
|
@ -314,7 +312,7 @@ mod tests {
|
|||
fn display() {
|
||||
let a: LinComb<Bn128Field> =
|
||||
LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21));
|
||||
assert_eq!(&a.to_string(), "3 * _21 + 1 * _42");
|
||||
assert_eq!(&a.to_string(), "1 * _42 + 3 * _21");
|
||||
let zero: LinComb<Bn128Field> = LinComb::zero();
|
||||
assert_eq!(&zero.to_string(), "0");
|
||||
}
|
||||
|
@ -350,7 +348,7 @@ mod tests {
|
|||
+ LinComb::summand(4, FlatVariable::new(33)),
|
||||
right: LinComb::summand(1, FlatVariable::new(21)),
|
||||
};
|
||||
assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)");
|
||||
assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)");
|
||||
let a: QuadComb<Bn128Field> = QuadComb {
|
||||
left: LinComb::zero(),
|
||||
right: LinComb::summand(1, FlatVariable::new(21)),
|
||||
|
@ -371,7 +369,7 @@ mod tests {
|
|||
]);
|
||||
assert_eq!(
|
||||
summand.try_summand(),
|
||||
Some((FlatVariable::new(42), Bn128Field::from(6)))
|
||||
Ok((FlatVariable::new(42), Bn128Field::from(6)))
|
||||
);
|
||||
|
||||
let not_summand = LinComb(vec![
|
||||
|
@ -379,10 +377,10 @@ mod tests {
|
|||
(FlatVariable::new(42), Bn128Field::from(2)),
|
||||
(FlatVariable::new(42), Bn128Field::from(3)),
|
||||
]);
|
||||
assert_eq!(not_summand.try_summand(), None);
|
||||
assert!(not_summand.try_summand().is_err());
|
||||
|
||||
let empty: LinComb<Bn128Field> = LinComb(vec![]);
|
||||
assert_eq!(empty.try_summand(), None);
|
||||
assert!(empty.try_summand().is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ impl<T: Field> From<FlatDirective<T>> for Directive<T> {
|
|||
inputs: ds
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(|i| QuadComb::from_flat_expression(i))
|
||||
.map(QuadComb::from_flat_expression)
|
||||
.collect(),
|
||||
solver: ds.solver,
|
||||
outputs: ds.outputs,
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::ir::Directive;
|
||||
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
|
||||
use crate::solvers::{Executable, Solver};
|
||||
use crate::solvers::Solver;
|
||||
use pairing_ce::bn256::Bn256;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt;
|
||||
#[cfg(feature = "bellman")]
|
||||
use zokrates_embed::generate_sha256_round_witness;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub type ExecutionResult<T> = Result<Witness<T>, Error>;
|
||||
|
@ -34,13 +37,13 @@ impl Interpreter {
|
|||
}
|
||||
|
||||
impl Interpreter {
|
||||
pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &Vec<T>) -> ExecutionResult<T> {
|
||||
pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &[T]) -> ExecutionResult<T> {
|
||||
let main = &program.main;
|
||||
self.check_inputs(&program, &inputs)?;
|
||||
let mut witness = BTreeMap::new();
|
||||
witness.insert(FlatVariable::one(), T::one());
|
||||
for (arg, value) in main.arguments.iter().zip(inputs.iter()) {
|
||||
witness.insert(arg.clone(), value.clone().into());
|
||||
witness.insert(*arg, value.clone());
|
||||
}
|
||||
|
||||
for statement in main.statements.iter() {
|
||||
|
@ -48,7 +51,7 @@ impl Interpreter {
|
|||
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
|
||||
true => {
|
||||
let val = quad.evaluate(&witness).unwrap();
|
||||
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
|
||||
witness.insert(lin.0.get(0).unwrap().0, val);
|
||||
}
|
||||
false => {
|
||||
let lhs_value = quad.evaluate(&witness).unwrap();
|
||||
|
@ -76,10 +79,10 @@ impl Interpreter {
|
|||
.iter()
|
||||
.map(|i| i.evaluate(&witness).unwrap())
|
||||
.collect();
|
||||
match d.solver.execute(&inputs) {
|
||||
match self.execute_solver(&d.solver, &inputs) {
|
||||
Ok(res) => {
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(o.clone(), res[i].clone());
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
@ -107,12 +110,12 @@ impl Interpreter {
|
|||
value.to_biguint()
|
||||
};
|
||||
|
||||
let mut num = input.clone();
|
||||
let mut num = input;
|
||||
let mut res = vec![];
|
||||
let bits = T::get_required_bits();
|
||||
for i in (0..bits).rev() {
|
||||
if T::from(2).to_biguint().pow(i as usize) <= num {
|
||||
num = num - T::from(2).to_biguint().pow(i as usize);
|
||||
num -= T::from(2).to_biguint().pow(i as usize);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
|
@ -120,11 +123,11 @@ impl Interpreter {
|
|||
}
|
||||
assert_eq!(num, T::zero().to_biguint());
|
||||
for (i, o) in d.outputs.iter().enumerate() {
|
||||
witness.insert(o.clone(), res[i].clone());
|
||||
witness.insert(*o, res[i].clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &Vec<U>) -> Result<(), Error> {
|
||||
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
|
||||
if program.main.arguments.len() == inputs.len() {
|
||||
Ok(())
|
||||
} else {
|
||||
|
@ -134,26 +137,136 @@ impl Interpreter {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_solver<T: Field>(
|
||||
&self,
|
||||
solver: &Solver,
|
||||
inputs: &[T],
|
||||
) -> Result<Vec<T>, String> {
|
||||
let (expected_input_count, expected_output_count) = solver.get_signature();
|
||||
assert!(inputs.len() == expected_input_count);
|
||||
|
||||
let res = match solver {
|
||||
Solver::ConditionEq => match inputs[0].is_zero() {
|
||||
true => vec![T::zero(), T::one()],
|
||||
false => vec![
|
||||
T::one(),
|
||||
T::one().checked_div(&inputs[0]).unwrap_or_else(T::one),
|
||||
],
|
||||
},
|
||||
Solver::Bits(bit_width) => {
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..*bit_width).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
Solver::Xor => {
|
||||
let x = inputs[0].clone();
|
||||
let y = inputs[1].clone();
|
||||
|
||||
vec![x.clone() + y.clone() - T::from(2) * x * y]
|
||||
}
|
||||
Solver::Or => {
|
||||
let x = inputs[0].clone();
|
||||
let y = inputs[1].clone();
|
||||
|
||||
vec![x.clone() + y.clone() - x * y]
|
||||
}
|
||||
// res = b * c - (2b * c - b - c) * (a)
|
||||
Solver::ShaAndXorAndXorAnd => {
|
||||
let a = inputs[0].clone();
|
||||
let b = inputs[1].clone();
|
||||
let c = inputs[2].clone();
|
||||
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
|
||||
}
|
||||
// res = a(b - c) + c
|
||||
Solver::ShaCh => {
|
||||
let a = inputs[0].clone();
|
||||
let b = inputs[1].clone();
|
||||
let c = inputs[2].clone();
|
||||
vec![a * (b - c.clone()) + c]
|
||||
}
|
||||
|
||||
Solver::Div => vec![inputs[0]
|
||||
.clone()
|
||||
.checked_div(&inputs[1])
|
||||
.unwrap_or_else(T::one)],
|
||||
Solver::EuclideanDiv => {
|
||||
use num::CheckedDiv;
|
||||
|
||||
let n = inputs[0].clone().to_biguint();
|
||||
let d = inputs[1].clone().to_biguint();
|
||||
|
||||
let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into());
|
||||
let r = n - d * &q;
|
||||
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
|
||||
}
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => {
|
||||
use zokrates_field::Bn128Field;
|
||||
assert_eq!(T::id(), Bn128Field::id());
|
||||
let i = &inputs[0..512];
|
||||
let h = &inputs[512..];
|
||||
let to_fr = |x: &T| {
|
||||
use pairing_ce::ff::{PrimeField, ScalarEngine};
|
||||
let s = x.to_dec_string();
|
||||
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
|
||||
};
|
||||
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
|
||||
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
|
||||
assert_eq!(h.len(), 256);
|
||||
generate_sha256_round_witness::<Bn256>(&i, &h)
|
||||
.into_iter()
|
||||
.map(|x| {
|
||||
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
|
||||
let mut res: Vec<u8> = vec![];
|
||||
x.into_repr().write_le(&mut res).unwrap();
|
||||
T::from_byte_vector(res)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
assert_eq!(res.len(), expected_output_count);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EvaluationError;
|
||||
|
||||
impl<T: Field> LinComb<T> {
|
||||
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> {
|
||||
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
|
||||
self.0
|
||||
.iter()
|
||||
.map(|(var, mult)| witness.get(var).map(|v| v.clone() * mult).ok_or(())) // get each term
|
||||
.map(|(var, mult)| {
|
||||
witness
|
||||
.get(var)
|
||||
.map(|v| v.clone() * mult)
|
||||
.ok_or(EvaluationError)
|
||||
}) // get each term
|
||||
.collect::<Result<Vec<_>, _>>() // fail if any term isn't found
|
||||
.map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum
|
||||
}
|
||||
|
||||
fn is_assignee<U>(&self, witness: &BTreeMap<FlatVariable, U>) -> bool {
|
||||
self.0.iter().count() == 1
|
||||
&& self.0.iter().next().unwrap().1 == T::from(1)
|
||||
&& !witness.contains_key(&self.0.iter().next().unwrap().0)
|
||||
&& self.0.get(0).unwrap().1 == T::from(1)
|
||||
&& !witness.contains_key(&self.0.get(0).unwrap().0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> QuadComb<T> {
|
||||
pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> {
|
||||
pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
|
||||
let left = self.left.evaluate(&witness)?;
|
||||
let right = self.right.evaluate(&witness)?;
|
||||
Ok(left * right)
|
||||
|
@ -192,3 +305,83 @@ impl fmt::Debug for Error {
|
|||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
mod eq_condition {
|
||||
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![0];
|
||||
let interpreter = Interpreter::default();
|
||||
let r = interpreter
|
||||
.execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute_non_eq() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![1];
|
||||
let interpreter = Interpreter::default();
|
||||
let r = interpreter
|
||||
.execute_solver(
|
||||
&cond_eq,
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&i| Bn128Field::from(i))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bits_of_one() {
|
||||
let inputs = vec![Bn128Field::from(1)];
|
||||
let interpreter = Interpreter::default();
|
||||
let res = interpreter
|
||||
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], Bn128Field::from(1));
|
||||
for i in 0..253 {
|
||||
assert_eq!(res[i], Bn128Field::from(0));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bits_of_42() {
|
||||
let inputs = vec![Bn128Field::from(42)];
|
||||
let interpreter = Interpreter::default();
|
||||
let res = interpreter
|
||||
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], Bn128Field::from(0));
|
||||
assert_eq!(res[252], Bn128Field::from(1));
|
||||
assert_eq!(res[251], Bn128Field::from(0));
|
||||
assert_eq!(res[250], Bn128Field::from(1));
|
||||
assert_eq!(res[249], Bn128Field::from(0));
|
||||
assert_eq!(res[248], Bn128Field::from(1));
|
||||
assert_eq!(res[247], Bn128Field::from(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ use crate::flat_absy::FlatVariable;
|
|||
use crate::solvers::Solver;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::hash::Hash;
|
||||
use zokrates_field::Field;
|
||||
|
||||
mod expression;
|
||||
|
@ -19,26 +20,12 @@ pub use self::serialize::ProgEnum;
|
|||
pub use self::interpreter::{Error, ExecutionResult, Interpreter};
|
||||
pub use self::witness::Witness;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Hash)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
|
||||
pub enum Statement<T> {
|
||||
Constraint(QuadComb<T>, LinComb<T>),
|
||||
Directive(Directive<T>),
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for Statement<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Statement::Constraint(l1, r1), Statement::Constraint(l2, r2)) => {
|
||||
l1.eq(l2) && r1.eq(r2)
|
||||
}
|
||||
(Statement::Directive(d1), Statement::Directive(d2)) => d1.eq(d2),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Eq for Statement<T> {}
|
||||
|
||||
impl<T: Field> Statement<T> {
|
||||
pub fn definition<U: Into<QuadComb<T>>>(v: FlatVariable, e: U) -> Self {
|
||||
Statement::Constraint(e.into(), v.into())
|
||||
|
@ -49,23 +36,13 @@ impl<T: Field> Statement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Hash)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||
pub struct Directive<T> {
|
||||
pub inputs: Vec<QuadComb<T>>,
|
||||
pub outputs: Vec<FlatVariable>,
|
||||
pub solver: Solver,
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for Directive<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.inputs.eq(&other.inputs)
|
||||
&& self.outputs.eq(&other.outputs)
|
||||
&& self.solver.eq(&other.solver)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Eq for Directive<T> {}
|
||||
|
||||
impl<T: Field> fmt::Display for Directive<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
|
@ -95,7 +72,7 @@ impl<T: Field> fmt::Display for Statement<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct Function<T> {
|
||||
pub id: String,
|
||||
pub statements: Vec<Statement<T>>,
|
||||
|
@ -103,15 +80,6 @@ pub struct Function<T> {
|
|||
pub returns: Vec<FlatVariable>,
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for Function<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id.eq(&other.id)
|
||||
&& self.statements.eq(&other.statements)
|
||||
&& self.arguments.eq(&other.arguments)
|
||||
&& self.returns.eq(&other.returns)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> fmt::Display for Function<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
|
@ -138,27 +106,18 @@ impl<T: Field> fmt::Display for Function<T> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct Prog<T> {
|
||||
pub main: Function<T>,
|
||||
pub private: Vec<bool>,
|
||||
}
|
||||
|
||||
impl<T: Field> PartialEq for Prog<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.main.eq(&other.main) && self.private.eq(&other.private)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Field> Prog<T> {
|
||||
pub fn constraint_count(&self) -> usize {
|
||||
self.main
|
||||
.statements
|
||||
.iter()
|
||||
.filter(|s| match s {
|
||||
Statement::Constraint(..) => true,
|
||||
_ => false,
|
||||
})
|
||||
.filter(|s| matches!(s, Statement::Constraint(..)))
|
||||
.count()
|
||||
}
|
||||
|
||||
|
|
|
@ -16,9 +16,9 @@ pub enum ProgEnum {
|
|||
|
||||
impl<T: Field> Prog<T> {
|
||||
pub fn serialize<W: Write>(&self, mut w: W) {
|
||||
w.write(ZOKRATES_MAGIC).unwrap();
|
||||
w.write(ZOKRATES_VERSION_1).unwrap();
|
||||
w.write(&T::id()).unwrap();
|
||||
w.write_all(ZOKRATES_MAGIC).unwrap();
|
||||
w.write_all(ZOKRATES_VERSION_1).unwrap();
|
||||
w.write_all(&T::id()).unwrap();
|
||||
|
||||
serialize_into(&mut w, self, Infinite).unwrap();
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ impl fmt::Display for Error {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn process_macros<'ast, T: Field>(file: File<'ast>) -> Result<File<'ast>, Error> {
|
||||
pub fn process_macros<T: Field>(file: File) -> Result<File, Error> {
|
||||
match &file.pragma {
|
||||
Some(pragma) => {
|
||||
if T::name() != pragma.curve.name {
|
||||
|
|
10
zokrates_core/src/optimizer/canonicalizer.rs
Normal file
10
zokrates_core/src/optimizer/canonicalizer.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
use crate::ir::{folder::Folder, LinComb};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct Canonicalizer;
|
||||
|
||||
impl<T: Field> Folder<T> for Canonicalizer {
|
||||
fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> {
|
||||
l.into_canonical().into()
|
||||
}
|
||||
}
|
|
@ -12,10 +12,10 @@
|
|||
use crate::flat_absy::flat_variable::FlatVariable;
|
||||
use crate::ir::folder::*;
|
||||
use crate::ir::*;
|
||||
use crate::optimizer::canonicalizer::Canonicalizer;
|
||||
use crate::solvers::Solver;
|
||||
use std::collections::hash_map::{Entry, HashMap};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DirectiveOptimizer<T: Field> {
|
||||
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
|
||||
|
@ -37,6 +37,23 @@ impl<T: Field> DirectiveOptimizer<T> {
|
|||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
||||
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
|
||||
// in order to correcty identify duplicates, we need to first canonicalize the statements
|
||||
|
||||
let mut canonicalizer = Canonicalizer;
|
||||
|
||||
let f = Function {
|
||||
statements: f
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| canonicalizer.fold_statement(s))
|
||||
.collect(),
|
||||
..f
|
||||
};
|
||||
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
match s {
|
||||
Statement::Directive(d) => {
|
||||
|
@ -49,7 +66,7 @@ impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
|
|||
}
|
||||
Entry::Occupied(e) => {
|
||||
self.substitution
|
||||
.extend(d.outputs.into_iter().zip(e.get().into_iter().cloned()));
|
||||
.extend(d.outputs.into_iter().zip(e.get().iter().cloned()));
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
//! Module containing the `DuplicateOptimizer` to remove duplicate constraints
|
||||
|
||||
use crate::ir::folder::Folder;
|
||||
use crate::ir::folder::*;
|
||||
use crate::ir::*;
|
||||
use crate::optimizer::canonicalizer::Canonicalizer;
|
||||
use std::collections::{hash_map::DefaultHasher, HashSet};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -33,6 +34,22 @@ impl DuplicateOptimizer {
|
|||
}
|
||||
|
||||
impl<T: Field> Folder<T> for DuplicateOptimizer {
|
||||
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
|
||||
// in order to correcty identify duplicates, we need to first canonicalize the statements
|
||||
let mut canonicalizer = Canonicalizer;
|
||||
|
||||
let f = Function {
|
||||
statements: f
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| canonicalizer.fold_statement(s))
|
||||
.collect(),
|
||||
..f
|
||||
};
|
||||
|
||||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
let hashed = hash(&s);
|
||||
let result = match self.seen.get(&hashed) {
|
||||
|
@ -120,7 +137,7 @@ mod tests {
|
|||
main: Function {
|
||||
id: "main".to_string(),
|
||||
statements: vec![
|
||||
constraint.clone(),
|
||||
constraint,
|
||||
Statement::Constraint(
|
||||
QuadComb::from_linear_combinations(
|
||||
LinComb::summand(3, FlatVariable::new(42)),
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
mod canonicalizer;
|
||||
mod directive;
|
||||
mod duplicate;
|
||||
mod redefinition;
|
||||
|
@ -26,7 +27,6 @@ impl<T: Field> Prog<T> {
|
|||
// // deduplicate directives which take the same input
|
||||
let r = DirectiveOptimizer::optimize(r);
|
||||
// remove duplicate constraints
|
||||
let r = DuplicateOptimizer::optimize(r);
|
||||
r
|
||||
DuplicateOptimizer::optimize(r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,6 @@ use crate::flat_absy::flat_variable::FlatVariable;
|
|||
use crate::ir::folder::{fold_function, Folder};
|
||||
use crate::ir::LinComb;
|
||||
use crate::ir::*;
|
||||
use crate::solvers::Executable;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -53,7 +52,7 @@ pub struct RedefinitionOptimizer<T: Field> {
|
|||
}
|
||||
|
||||
impl<T: Field> RedefinitionOptimizer<T> {
|
||||
fn new() -> RedefinitionOptimizer<T> {
|
||||
fn new() -> Self {
|
||||
RedefinitionOptimizer {
|
||||
substitution: HashMap::new(),
|
||||
ignore: HashSet::new(),
|
||||
|
@ -72,84 +71,77 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
let quad = self.fold_quadratic_combination(quad);
|
||||
let lin = self.fold_linear_combination(lin);
|
||||
|
||||
let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() {
|
||||
// if the right side is a single variable
|
||||
Some((variable, coefficient)) => {
|
||||
match self.ignore.contains(&variable) {
|
||||
// if the variable isn't tagged as ignored
|
||||
false => match self.substitution.get(&variable) {
|
||||
// if the variable is already defined
|
||||
Some(_) => (true, None, None),
|
||||
// if the variable is not defined yet
|
||||
None => match quad.try_linear() {
|
||||
// if the left side is linear
|
||||
Some(l) => (false, Some((variable, l / &coefficient)), None),
|
||||
// if the left side isn't linear
|
||||
None => (true, None, Some(variable)),
|
||||
},
|
||||
},
|
||||
true => (true, None, None),
|
||||
}
|
||||
}
|
||||
None => (true, None, None),
|
||||
if lin.is_zero() {
|
||||
return vec![Statement::Constraint(quad, lin)];
|
||||
}
|
||||
|
||||
let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0)
|
||||
|| self.substitution.contains_key(&lin.0[0].0)
|
||||
{
|
||||
true => (Some(Statement::Constraint(quad, lin)), None, None),
|
||||
false => match lin.try_summand() {
|
||||
// if the right side is a single variable
|
||||
Ok((variable, coefficient)) => match quad.try_linear() {
|
||||
// if the left side is linear
|
||||
Ok(l) => (None, Some((variable, l / &coefficient)), None),
|
||||
// if the left side isn't linear
|
||||
Err(quad) => (
|
||||
Some(Statement::Constraint(
|
||||
quad,
|
||||
LinComb::summand(coefficient, variable),
|
||||
)),
|
||||
None,
|
||||
Some(variable),
|
||||
),
|
||||
},
|
||||
Err(l) => (Some(Statement::Constraint(quad, l)), None, None),
|
||||
},
|
||||
};
|
||||
|
||||
// insert into the ignored set
|
||||
match to_ignore {
|
||||
Some(v) => {
|
||||
self.ignore.insert(v);
|
||||
}
|
||||
None => {}
|
||||
if let Some(v) = to_ignore {
|
||||
self.ignore.insert(v);
|
||||
}
|
||||
|
||||
// insert into the substitution map
|
||||
match to_insert {
|
||||
Some((k, v)) => {
|
||||
self.substitution.insert(k, v.into_canonical());
|
||||
}
|
||||
None => {}
|
||||
if let Some((k, v)) = to_insert {
|
||||
self.substitution.insert(k, v.into_canonical());
|
||||
};
|
||||
|
||||
// decide whether the constraint should be kept
|
||||
match keep_constraint {
|
||||
false => vec![],
|
||||
true => vec![Statement::Constraint(quad, lin)],
|
||||
match constraint {
|
||||
Some(c) => vec![c],
|
||||
_ => vec![],
|
||||
}
|
||||
}
|
||||
Statement::Directive(d) => {
|
||||
let d = self.fold_directive(d);
|
||||
|
||||
// check if the inputs are constants, ie reduce to the form `coeff * ~one`
|
||||
let inputs = d
|
||||
let inputs: Vec<_> = d
|
||||
.inputs
|
||||
.into_iter()
|
||||
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
|
||||
.map(|i| i.reduce())
|
||||
.map(|q| match q.try_linear() {
|
||||
Some(l) => match l.0.len() {
|
||||
// 0 is constant and can be represented by an empty lincomb
|
||||
0 => Ok(T::from(0)),
|
||||
_ => l
|
||||
// try to match to a single summand `coeff * v`
|
||||
.try_summand()
|
||||
.map(|(variable, coefficient)| match variable {
|
||||
// v must be ~one
|
||||
v if v == FlatVariable::one() => Ok(coefficient),
|
||||
_ => Err(LinComb::summand(coefficient, variable).into()),
|
||||
})
|
||||
.unwrap_or(Err(l.into())),
|
||||
},
|
||||
None => Err(q),
|
||||
.map(|q| {
|
||||
match q
|
||||
.try_linear()
|
||||
.map(|l| l.try_constant().map_err(|l| l.into()))
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Result<T, QuadComb<T>>>>();
|
||||
|
||||
match inputs.iter().all(|r| r.is_ok()) {
|
||||
match inputs.iter().all(|i| i.is_ok()) {
|
||||
true => {
|
||||
// unwrap inputs to their constant value
|
||||
let inputs = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
|
||||
// run the solver
|
||||
let outputs = d.solver.execute(&inputs).unwrap();
|
||||
|
||||
let outputs = Interpreter::default()
|
||||
.execute_solver(&d.solver, &inputs)
|
||||
.unwrap();
|
||||
assert_eq!(outputs.len(), d.outputs.len());
|
||||
|
||||
// insert the results in the substitution
|
||||
|
@ -160,8 +152,8 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
vec![]
|
||||
}
|
||||
false => {
|
||||
// reconstruct the input expressions
|
||||
let inputs = inputs
|
||||
//reconstruct the input expressions
|
||||
let inputs: Vec<_> = inputs
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
i.map(|v| LinComb::summand(v, FlatVariable::one()).into())
|
||||
|
@ -183,8 +175,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
match lc
|
||||
.0
|
||||
.iter()
|
||||
.find(|(variable, _)| self.substitution.get(&variable).is_some())
|
||||
.is_some()
|
||||
.any(|(variable, _)| self.substitution.get(&variable).is_some())
|
||||
{
|
||||
true =>
|
||||
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
|
||||
|
@ -194,7 +185,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
self.substitution
|
||||
.get(&variable)
|
||||
.map(|l| LinComb::from(l.clone()) * &coefficient)
|
||||
.unwrap_or(LinComb::summand(coefficient, variable))
|
||||
.unwrap_or_else(|| LinComb::summand(coefficient, variable))
|
||||
})
|
||||
.fold(LinComb::zero(), |acc, x| acc + x)
|
||||
}
|
||||
|
@ -209,9 +200,6 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
|
|||
}
|
||||
|
||||
fn fold_function(&mut self, fun: Function<T>) -> Function<T> {
|
||||
self.substitution.drain();
|
||||
self.ignore.drain();
|
||||
|
||||
// to prevent the optimiser from replacing outputs, add them to the ignored set
|
||||
self.ignore.extend(fun.returns.iter().cloned());
|
||||
|
||||
|
@ -378,7 +366,7 @@ mod tests {
|
|||
// ->
|
||||
|
||||
// def main(x, y) -> (1):
|
||||
// 6*x + 6*y == 6*x + 6*y // will be eliminated as a tautology
|
||||
// 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology
|
||||
// return 6*x + 6*y
|
||||
|
||||
let x = FlatVariable::new(0);
|
||||
|
@ -412,7 +400,15 @@ mod tests {
|
|||
LinComb::summand(6, x) + LinComb::summand(6, y),
|
||||
LinComb::summand(6, x) + LinComb::summand(6, y),
|
||||
),
|
||||
Statement::definition(r, LinComb::summand(6, x) + LinComb::summand(6, y)),
|
||||
Statement::definition(
|
||||
r,
|
||||
LinComb::summand(1, x)
|
||||
+ LinComb::summand(1, y)
|
||||
+ LinComb::summand(2, x)
|
||||
+ LinComb::summand(2, y)
|
||||
+ LinComb::summand(3, x)
|
||||
+ LinComb::summand(3, y),
|
||||
),
|
||||
],
|
||||
returns: vec![r],
|
||||
};
|
||||
|
|
|
@ -25,17 +25,16 @@ impl TautologyOptimizer {
|
|||
impl<T: Field> Folder<T> for TautologyOptimizer {
|
||||
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
|
||||
match s {
|
||||
Statement::Constraint(quad, lin) => {
|
||||
match quad.try_linear() {
|
||||
Some(l) => {
|
||||
if l == lin {
|
||||
return vec![];
|
||||
}
|
||||
Statement::Constraint(quad, lin) => match quad.try_linear() {
|
||||
Ok(l) => {
|
||||
if l == lin {
|
||||
vec![]
|
||||
} else {
|
||||
vec![Statement::Constraint(l.into(), lin)]
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
vec![Statement::Constraint(quad, lin)]
|
||||
}
|
||||
Err(quad) => vec![Statement::Constraint(quad, lin)],
|
||||
},
|
||||
_ => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ impl<T: Field + ArkFieldExtensions + NotBw6_761Field> Backend<T, GM17> for Ark {
|
|||
query: vk
|
||||
.query
|
||||
.into_iter()
|
||||
.map(|g1| serialization::to_g1::<T>(g1))
|
||||
.map(serialization::to_g1::<T>)
|
||||
.collect(),
|
||||
};
|
||||
|
||||
|
@ -172,7 +172,7 @@ impl Backend<Bw6_761Field, GM17> for Ark {
|
|||
query: vk
|
||||
.query
|
||||
.into_iter()
|
||||
.map(|g1| serialization::to_g1::<Bw6_761Field>(g1))
|
||||
.map(serialization::to_g1::<Bw6_761Field>)
|
||||
.collect(),
|
||||
};
|
||||
|
||||
|
|
|
@ -49,42 +49,33 @@ fn ark_combination<T: Field + ArkFieldExtensions>(
|
|||
cs: &mut ConstraintSystem<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>,
|
||||
symbols: &mut BTreeMap<FlatVariable, Variable>,
|
||||
witness: &mut Witness<T>,
|
||||
) -> Result<
|
||||
LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>,
|
||||
SynthesisError,
|
||||
> {
|
||||
let lc =
|
||||
l.0.into_iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
v.into_ark(),
|
||||
symbols
|
||||
.entry(k)
|
||||
.or_insert_with(|| {
|
||||
match k.is_output() {
|
||||
true => cs.new_input_variable(|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_ark())
|
||||
}),
|
||||
false => cs.new_witness_variable(|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_ark())
|
||||
}),
|
||||
}
|
||||
.unwrap()
|
||||
})
|
||||
.clone(),
|
||||
)
|
||||
})
|
||||
.fold(LinearCombination::zero(), |acc, e| acc + e);
|
||||
|
||||
Ok(lc)
|
||||
) -> LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr> {
|
||||
l.0.into_iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
v.into_ark(),
|
||||
*symbols.entry(k).or_insert_with(|| {
|
||||
match k.is_output() {
|
||||
true => cs.new_input_variable(|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_ark())
|
||||
}),
|
||||
false => cs.new_witness_variable(|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_ark())
|
||||
}),
|
||||
}
|
||||
.unwrap()
|
||||
}),
|
||||
)
|
||||
})
|
||||
.fold(LinearCombination::zero(), |acc, e| acc + e)
|
||||
}
|
||||
|
||||
impl<T: Field + ArkFieldExtensions> Prog<T> {
|
||||
|
@ -96,7 +87,7 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
|
|||
// mapping from IR variables
|
||||
let mut symbols = BTreeMap::new();
|
||||
|
||||
let mut witness = witness.unwrap_or(Witness::empty());
|
||||
let mut witness = witness.unwrap_or_else(Witness::empty);
|
||||
|
||||
assert!(symbols.insert(FlatVariable::one(), ConstraintSystem::<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>::one()).is_none());
|
||||
|
||||
|
@ -127,37 +118,34 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
|
|||
}),
|
||||
}
|
||||
.unwrap();
|
||||
(var.clone(), wire)
|
||||
(*var, wire)
|
||||
}),
|
||||
);
|
||||
|
||||
let main = self.main;
|
||||
|
||||
for statement in main.statements {
|
||||
match statement {
|
||||
Statement::Constraint(quad, lin) => {
|
||||
let a = ark_combination(
|
||||
quad.left.clone().into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
)?;
|
||||
let b = ark_combination(
|
||||
quad.right.clone().into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
)?;
|
||||
let c = ark_combination(
|
||||
lin.into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
)?;
|
||||
if let Statement::Constraint(quad, lin) = statement {
|
||||
let a = ark_combination(
|
||||
quad.left.clone().into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let b = ark_combination(
|
||||
quad.right.clone().into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let c = ark_combination(
|
||||
lin.into_canonical(),
|
||||
&mut cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
|
||||
cs.enforce_constraint(a, b, c)?;
|
||||
}
|
||||
_ => {}
|
||||
cs.enforce_constraint(a, b, c)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
|
|||
ic: vk
|
||||
.gamma_abc
|
||||
.into_iter()
|
||||
.map(|g1| serialization::to_g1::<T>(g1))
|
||||
.map(serialization::to_g1::<T>)
|
||||
.collect(),
|
||||
};
|
||||
|
||||
|
|
|
@ -51,34 +51,31 @@ fn bellman_combination<T: BellmanFieldExtensions, CS: ConstraintSystem<T::Bellma
|
|||
.map(|(k, v)| {
|
||||
(
|
||||
v.into_bellman(),
|
||||
symbols
|
||||
.entry(k)
|
||||
.or_insert_with(|| {
|
||||
match k.is_output() {
|
||||
true => cs.alloc_input(
|
||||
|| format!("{}", k),
|
||||
|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_bellman())
|
||||
},
|
||||
),
|
||||
false => cs.alloc(
|
||||
|| format!("{}", k),
|
||||
|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_bellman())
|
||||
},
|
||||
),
|
||||
}
|
||||
.unwrap()
|
||||
})
|
||||
.clone(),
|
||||
*symbols.entry(k).or_insert_with(|| {
|
||||
match k.is_output() {
|
||||
true => cs.alloc_input(
|
||||
|| format!("{}", k),
|
||||
|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_bellman())
|
||||
},
|
||||
),
|
||||
false => cs.alloc(
|
||||
|| format!("{}", k),
|
||||
|| {
|
||||
Ok(witness
|
||||
.0
|
||||
.remove(&k)
|
||||
.ok_or(SynthesisError::AssignmentMissing)?
|
||||
.into_bellman())
|
||||
},
|
||||
),
|
||||
}
|
||||
.unwrap()
|
||||
}),
|
||||
)
|
||||
})
|
||||
.fold(LinearCombination::zero(), |acc, e| acc + e)
|
||||
|
@ -93,7 +90,7 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
|
|||
// mapping from IR variables
|
||||
let mut symbols = BTreeMap::new();
|
||||
|
||||
let mut witness = witness.unwrap_or(Witness::empty());
|
||||
let mut witness = witness.unwrap_or_else(Witness::empty);
|
||||
|
||||
assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none());
|
||||
|
||||
|
@ -127,33 +124,29 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
|
|||
),
|
||||
}
|
||||
.unwrap();
|
||||
(var.clone(), wire)
|
||||
(*var, wire)
|
||||
}),
|
||||
);
|
||||
|
||||
let main = self.main;
|
||||
|
||||
for statement in main.statements {
|
||||
match statement {
|
||||
Statement::Constraint(quad, lin) => {
|
||||
let a = &bellman_combination(
|
||||
quad.left.into_canonical(),
|
||||
cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let b = &bellman_combination(
|
||||
quad.right.into_canonical(),
|
||||
cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let c =
|
||||
&bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
|
||||
if let Statement::Constraint(quad, lin) = statement {
|
||||
let a = &bellman_combination(
|
||||
quad.left.into_canonical(),
|
||||
cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let b = &bellman_combination(
|
||||
quad.right.into_canonical(),
|
||||
cs,
|
||||
&mut symbols,
|
||||
&mut witness,
|
||||
);
|
||||
let c = &bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
|
||||
|
||||
cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c);
|
||||
}
|
||||
_ => {}
|
||||
cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,10 +1,5 @@
|
|||
#[cfg(feature = "bellman")]
|
||||
use pairing_ce::bn256::Bn256;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
#[cfg(feature = "bellman")]
|
||||
use zokrates_embed::generate_sha256_round_witness;
|
||||
use zokrates_field::{Bn128Field, Field};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||
pub enum Solver {
|
||||
|
@ -48,168 +43,3 @@ impl Solver {
|
|||
Solver::Bits(width)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Executable<T> {
|
||||
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String>;
|
||||
}
|
||||
|
||||
impl<T: Field> Executable<T> for Solver {
|
||||
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
|
||||
let (expected_input_count, expected_output_count) = self.get_signature();
|
||||
assert_eq!(inputs.len(), expected_input_count);
|
||||
|
||||
let res = match self {
|
||||
Solver::ConditionEq => match inputs[0].is_zero() {
|
||||
true => vec![T::zero(), T::one()],
|
||||
false => vec![
|
||||
T::one(),
|
||||
T::one().checked_div(&inputs[0]).unwrap_or(T::one()),
|
||||
],
|
||||
},
|
||||
Solver::Bits(bit_width) => {
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
|
||||
for i in (0..*bit_width).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
} else {
|
||||
res.push(T::zero());
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
Solver::Xor => {
|
||||
let x = inputs[0].clone();
|
||||
let y = inputs[1].clone();
|
||||
|
||||
vec![x.clone() + y.clone() - T::from(2) * x * y]
|
||||
}
|
||||
Solver::Or => {
|
||||
let x = inputs[0].clone();
|
||||
let y = inputs[1].clone();
|
||||
|
||||
vec![x.clone() + y.clone() - x * y]
|
||||
}
|
||||
// res = b * c - (2b * c - b - c) * (a)
|
||||
Solver::ShaAndXorAndXorAnd => {
|
||||
let a = inputs[0].clone();
|
||||
let b = inputs[1].clone();
|
||||
let c = inputs[2].clone();
|
||||
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
|
||||
}
|
||||
// res = a(b - c) + c
|
||||
Solver::ShaCh => {
|
||||
let a = inputs[0].clone();
|
||||
let b = inputs[1].clone();
|
||||
let c = inputs[2].clone();
|
||||
vec![a * (b - c.clone()) + c]
|
||||
}
|
||||
Solver::Div => vec![inputs[0]
|
||||
.clone()
|
||||
.checked_div(&inputs[1])
|
||||
.unwrap_or(T::one())],
|
||||
Solver::EuclideanDiv => {
|
||||
use num::CheckedDiv;
|
||||
|
||||
let n = inputs[0].clone().to_biguint();
|
||||
let d = inputs[1].clone().to_biguint();
|
||||
|
||||
let q = n.checked_div(&d).unwrap_or(0u32.into());
|
||||
let r = n - d * &q;
|
||||
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
|
||||
}
|
||||
#[cfg(feature = "bellman")]
|
||||
Solver::Sha256Round => {
|
||||
assert_eq!(T::id(), Bn128Field::id());
|
||||
let i = &inputs[0..512];
|
||||
let h = &inputs[512..];
|
||||
let to_fr = |x: &T| {
|
||||
use pairing_ce::ff::{PrimeField, ScalarEngine};
|
||||
let s = x.to_dec_string();
|
||||
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
|
||||
};
|
||||
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
|
||||
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
|
||||
assert_eq!(h.len(), 256);
|
||||
generate_sha256_round_witness::<Bn256>(&i, &h)
|
||||
.into_iter()
|
||||
.map(|x| {
|
||||
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
|
||||
let mut res: Vec<u8> = vec![];
|
||||
x.into_repr().write_le(&mut res).unwrap();
|
||||
T::from_byte_vector(res)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
assert_eq!(res.len(), expected_output_count);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
mod eq_condition {
|
||||
|
||||
// Wanted: (Y = (X != 0) ? 1 : 0)
|
||||
// # Y = if X == 0 then 0 else 1 fi
|
||||
// # M = if X == 0 then 1 else 1/X fi
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn execute() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![0];
|
||||
let r = cond_eq
|
||||
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute_non_eq() {
|
||||
let cond_eq = Solver::ConditionEq;
|
||||
let inputs = vec![1];
|
||||
let r = cond_eq
|
||||
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
|
||||
.unwrap();
|
||||
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
|
||||
assert_eq!(r, &res[..]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bits_of_one() {
|
||||
let bits = Solver::Bits(Bn128Field::get_required_bits());
|
||||
let inputs = vec![Bn128Field::from(1)];
|
||||
let res = bits.execute(&inputs).unwrap();
|
||||
assert_eq!(res[253], Bn128Field::from(1));
|
||||
for i in 0..253 {
|
||||
assert_eq!(res[i], Bn128Field::from(0));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bits_of_42() {
|
||||
let bits = Solver::Bits(Bn128Field::get_required_bits());
|
||||
let inputs = vec![Bn128Field::from(42)];
|
||||
let res = bits.execute(&inputs).unwrap();
|
||||
|
||||
assert_eq!(res[253], Bn128Field::from(0));
|
||||
assert_eq!(res[252], Bn128Field::from(1));
|
||||
assert_eq!(res[251], Bn128Field::from(0));
|
||||
assert_eq!(res[250], Bn128Field::from(1));
|
||||
assert_eq!(res[249], Bn128Field::from(0));
|
||||
assert_eq!(res[248], Bn128Field::from(1));
|
||||
assert_eq!(res[247], Bn128Field::from(0));
|
||||
}
|
||||
}
|
||||
|
|
134
zokrates_core/src/static_analysis/bounds_checker.rs
Normal file
134
zokrates_core/src/static_analysis/bounds_checker.rs
Normal file
|
@ -0,0 +1,134 @@
|
|||
use crate::typed_absy::result_folder::*;
|
||||
use crate::typed_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct BoundsChecker;
|
||||
|
||||
pub type Error = String;
|
||||
|
||||
impl BoundsChecker {
|
||||
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
BoundsChecker.fold_program(p)
|
||||
}
|
||||
|
||||
pub fn check_select<'ast, T: Field, U: Select<'ast, T>>(
|
||||
&mut self,
|
||||
array: ArrayExpression<'ast, T>,
|
||||
index: UExpression<'ast, T>,
|
||||
) -> Result<U, Error> {
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let index = self.fold_uint_expression(index)?;
|
||||
|
||||
match (array.get_array_type().size.as_inner(), index.as_inner()) {
|
||||
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
|
||||
if index >= size {
|
||||
return Err(format!(
|
||||
"Out of bounds access: {}[{}] but {} is of size {}",
|
||||
array, index, array, size
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(U::select(array, index))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &ArrayType<'ast, T>,
|
||||
e: ArrayExpressionInner<'ast, T>,
|
||||
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
ArrayExpressionInner::Select(box array, box index) => self
|
||||
.check_select::<_, ArrayExpression<_>>(array, index)
|
||||
.map(|a| a.into_inner()),
|
||||
ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
let array = self.fold_array_expression(array)?;
|
||||
let from = self.fold_uint_expression(from)?;
|
||||
let to = self.fold_uint_expression(to)?;
|
||||
|
||||
match (
|
||||
array.get_array_type().size.as_inner(),
|
||||
from.as_inner(),
|
||||
to.as_inner(),
|
||||
) {
|
||||
(
|
||||
UExpressionInner::Value(size),
|
||||
UExpressionInner::Value(from),
|
||||
UExpressionInner::Value(to),
|
||||
) => {
|
||||
if from > to {
|
||||
return Err(format!(
|
||||
"Slice is created from an invalid range {}..{}",
|
||||
from, to
|
||||
));
|
||||
}
|
||||
|
||||
if from > size {
|
||||
return Err(format!("Lower bound {} of slice {}[{}..{}] is out of bounds for array of size {}", from, array, from, to, size));
|
||||
}
|
||||
|
||||
if to > size {
|
||||
return Err(format!("Upper bound {} of slice {}[{}..{}] is out of bounds for array of size {}", to, array, from, to, size));
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
|
||||
}
|
||||
e => fold_array_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType<'ast, T>,
|
||||
e: StructExpressionInner<'ast, T>,
|
||||
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
StructExpressionInner::Select(box array, box index) => self
|
||||
.check_select::<_, StructExpression<_>>(array, index)
|
||||
.map(|a| a.into_inner()),
|
||||
e => fold_struct_expression_inner(self, ty, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: FieldElementExpression<'ast, T>,
|
||||
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
FieldElementExpression::Select(box array, box index) => self.check_select(array, index),
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
BooleanExpression::Select(box array, box index) => self.check_select(array, index),
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
|
||||
match e {
|
||||
UExpressionInner::Select(box array, box index) => self
|
||||
.check_select::<_, UExpression<_>>(array, index)
|
||||
.map(|a| a.into_inner()),
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,31 +1,34 @@
|
|||
use crate::typed_absy;
|
||||
use crate::typed_absy::types::{StructType, UBitwidth};
|
||||
use crate::typed_absy::types::UBitwidth;
|
||||
use crate::zir;
|
||||
use std::marker::PhantomData;
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
pub struct Flattener<T: Field> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
fn flatten_identifier_rec<'a>(
|
||||
id: zir::SourceIdentifier<'a>,
|
||||
ty: &typed_absy::Type,
|
||||
) -> Vec<zir::Variable<'a>> {
|
||||
fn flatten_identifier_rec<'ast>(
|
||||
id: zir::SourceIdentifier<'ast>,
|
||||
ty: &typed_absy::types::ConcreteType,
|
||||
) -> Vec<zir::Variable<'ast>> {
|
||||
match ty {
|
||||
typed_absy::Type::FieldElement => vec![zir::Variable {
|
||||
typed_absy::ConcreteType::Int => unreachable!(),
|
||||
typed_absy::ConcreteType::FieldElement => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::FieldElement,
|
||||
}],
|
||||
typed_absy::Type::Boolean => vec![zir::Variable {
|
||||
typed_absy::types::ConcreteType::Boolean => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::Boolean,
|
||||
}],
|
||||
typed_absy::Type::Uint(bitwidth) => vec![zir::Variable {
|
||||
typed_absy::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable {
|
||||
id: zir::Identifier::Source(id),
|
||||
_type: zir::Type::uint(bitwidth.to_usize()),
|
||||
}],
|
||||
typed_absy::Type::Array(array_type) => (0..array_type.size)
|
||||
typed_absy::types::ConcreteType::Array(array_type) => (0..array_type.size)
|
||||
.flat_map(|i| {
|
||||
flatten_identifier_rec(
|
||||
zir::SourceIdentifier::Select(box id.clone(), i),
|
||||
|
@ -33,7 +36,7 @@ fn flatten_identifier_rec<'a>(
|
|||
)
|
||||
})
|
||||
.collect(),
|
||||
typed_absy::Type::Struct(members) => members
|
||||
typed_absy::types::ConcreteType::Struct(members) => members
|
||||
.iter()
|
||||
.flat_map(|struct_member| {
|
||||
flatten_identifier_rec(
|
||||
|
@ -57,17 +60,6 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
fold_program(self, p)
|
||||
}
|
||||
|
||||
fn fold_module(&mut self, p: typed_absy::TypedModule<'ast, T>) -> zir::ZirModule<'ast, T> {
|
||||
fold_module(self, p)
|
||||
}
|
||||
|
||||
fn fold_function_symbol(
|
||||
&mut self,
|
||||
s: typed_absy::TypedFunctionSymbol<'ast, T>,
|
||||
) -> zir::ZirFunctionSymbol<'ast, T> {
|
||||
fold_function_symbol(self, s)
|
||||
}
|
||||
|
||||
fn fold_function(
|
||||
&mut self,
|
||||
f: typed_absy::TypedFunction<'ast, T>,
|
||||
|
@ -75,9 +67,12 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
fold_function(self, f)
|
||||
}
|
||||
|
||||
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> {
|
||||
fn fold_declaration_parameter(
|
||||
&mut self,
|
||||
p: typed_absy::DeclarationParameter<'ast>,
|
||||
) -> Vec<zir::Parameter<'ast>> {
|
||||
let private = p.private;
|
||||
self.fold_variable(p.id)
|
||||
self.fold_variable(p.id.try_into().unwrap())
|
||||
.into_iter()
|
||||
.map(|v| zir::Parameter { id: v, private })
|
||||
.collect()
|
||||
|
@ -87,10 +82,12 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
zir::SourceIdentifier::Basic(n)
|
||||
}
|
||||
|
||||
fn fold_variable(&mut self, v: typed_absy::Variable<'ast>) -> Vec<zir::Variable<'ast>> {
|
||||
fn fold_variable(&mut self, v: typed_absy::Variable<'ast, T>) -> Vec<zir::Variable<'ast>> {
|
||||
let id = self.fold_name(v.id.clone());
|
||||
let ty = v.get_type();
|
||||
|
||||
let ty = typed_absy::types::ConcreteType::try_from(ty).unwrap();
|
||||
|
||||
flatten_identifier_rec(id, &ty)
|
||||
}
|
||||
|
||||
|
@ -102,36 +99,34 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
typed_absy::TypedAssignee::Identifier(v) => self.fold_variable(v),
|
||||
typed_absy::TypedAssignee::Select(box a, box i) => {
|
||||
use typed_absy::Typed;
|
||||
let count = match a.get_type() {
|
||||
typed_absy::Type::Array(array_ty) => array_ty.ty.get_primitive_count(),
|
||||
let count = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap() {
|
||||
typed_absy::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let a = self.fold_assignee(a);
|
||||
|
||||
match i {
|
||||
typed_absy::FieldElementExpression::Number(n) => {
|
||||
let index = n.to_dec_string().parse::<usize>().unwrap();
|
||||
a[index * count..(index + 1) * count].to_vec()
|
||||
match i.as_inner() {
|
||||
typed_absy::UExpressionInner::Value(index) => {
|
||||
a[*index as usize * count..(*index as usize + 1) * count].to_vec()
|
||||
}
|
||||
i => unreachable!("index {} not allowed, should be a constant", i),
|
||||
i => unreachable!("index {:?} not allowed, should be a constant", i),
|
||||
}
|
||||
}
|
||||
typed_absy::TypedAssignee::Member(box a, m) => {
|
||||
use typed_absy::Typed;
|
||||
|
||||
let (offset, size) = match a.get_type() {
|
||||
typed_absy::Type::Struct(struct_type) => {
|
||||
struct_type
|
||||
.members
|
||||
.iter()
|
||||
.fold((0, None), |(offset, size), member| match size {
|
||||
Some(_) => (offset, size),
|
||||
None => match m == member.id {
|
||||
true => (offset, Some(member.ty.get_primitive_count())),
|
||||
false => (offset + member.ty.get_primitive_count(), None),
|
||||
},
|
||||
})
|
||||
}
|
||||
let (offset, size) = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap()
|
||||
{
|
||||
typed_absy::ConcreteType::Struct(struct_type) => struct_type
|
||||
.members
|
||||
.iter()
|
||||
.fold((0, None), |(offset, size), member| match size {
|
||||
Some(_) => (offset, size),
|
||||
None => match m == member.id {
|
||||
true => (offset, Some(member.ty.get_primitive_count())),
|
||||
false => (offset + member.ty.get_primitive_count(), None),
|
||||
},
|
||||
}),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
|
@ -151,6 +146,16 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
fold_statement(self, s)
|
||||
}
|
||||
|
||||
fn fold_expression_or_spread(
|
||||
&mut self,
|
||||
e: typed_absy::TypedExpressionOrSpread<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match e {
|
||||
typed_absy::TypedExpressionOrSpread::Expression(e) => self.fold_expression(e),
|
||||
typed_absy::TypedExpressionOrSpread::Spread(s) => self.fold_array_expression(s.array),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_expression(
|
||||
&mut self,
|
||||
e: typed_absy::TypedExpression<'ast, T>,
|
||||
|
@ -161,8 +166,9 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
}
|
||||
typed_absy::TypedExpression::Boolean(e) => vec![self.fold_boolean_expression(e).into()],
|
||||
typed_absy::TypedExpression::Uint(e) => vec![self.fold_uint_expression(e).into()],
|
||||
typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e).into(),
|
||||
typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e).into(),
|
||||
typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e),
|
||||
typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e),
|
||||
typed_absy::TypedExpression::Int(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,26 +191,20 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
es: typed_absy::TypedExpressionList<'ast, T>,
|
||||
) -> zir::ZirExpressionList<'ast, T> {
|
||||
match es {
|
||||
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => {
|
||||
zir::ZirExpressionList::FunctionCall(
|
||||
self.fold_function_key(id),
|
||||
typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => {
|
||||
zir::ZirExpressionList::EmbedCall(
|
||||
embed,
|
||||
generics,
|
||||
arguments
|
||||
.into_iter()
|
||||
.flat_map(|a| self.fold_expression(a))
|
||||
.collect(),
|
||||
vec![],
|
||||
)
|
||||
}
|
||||
_ => unreachable!("should have been inlined"),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_function_key(
|
||||
&mut self,
|
||||
k: typed_absy::types::FunctionKey<'ast>,
|
||||
) -> zir::types::FunctionKey<'ast> {
|
||||
k.into()
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
e: typed_absy::FieldElementExpression<'ast, T>,
|
||||
|
@ -234,7 +234,7 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
|
||||
fn fold_array_expression_inner(
|
||||
&mut self,
|
||||
ty: &typed_absy::Type,
|
||||
ty: &typed_absy::types::ConcreteType,
|
||||
size: usize,
|
||||
e: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
|
@ -242,26 +242,13 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
}
|
||||
fn fold_struct_expression_inner(
|
||||
&mut self,
|
||||
ty: &StructType,
|
||||
ty: &typed_absy::types::ConcreteStructType,
|
||||
e: typed_absy::StructExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
fold_struct_expression_inner(self, ty, e)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_module<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
p: typed_absy::TypedModule<'ast, T>,
|
||||
) -> zir::ZirModule<'ast, T> {
|
||||
zir::ZirModule {
|
||||
functions: p
|
||||
.functions
|
||||
.into_iter()
|
||||
.map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_statement<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
s: typed_absy::TypedStatement<'ast, T>,
|
||||
|
@ -284,9 +271,7 @@ pub fn fold_statement<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::TypedStatement::Declaration(v) => {
|
||||
let v = f.fold_variable(v);
|
||||
v.into_iter()
|
||||
.map(|v| zir::ZirStatement::Declaration(v))
|
||||
.collect()
|
||||
v.into_iter().map(zir::ZirStatement::Declaration).collect()
|
||||
}
|
||||
typed_absy::TypedStatement::Assertion(e) => {
|
||||
let e = f.fold_boolean_expression(e);
|
||||
|
@ -302,19 +287,23 @@ pub fn fold_statement<'ast, T: Field>(
|
|||
f.fold_expression_list(elist),
|
||||
)]
|
||||
}
|
||||
typed_absy::TypedStatement::PushCallLog(..) => vec![],
|
||||
typed_absy::TypedStatement::PopCallLog => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_array_expression_inner<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
t: &typed_absy::Type,
|
||||
ty: &typed_absy::types::ConcreteType,
|
||||
size: usize,
|
||||
e: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
array: typed_absy::ArrayExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match e {
|
||||
match array {
|
||||
typed_absy::ArrayExpressionInner::Identifier(id) => {
|
||||
let variables =
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::array(t.clone(), size));
|
||||
let variables = flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
&typed_absy::types::ConcreteType::array((ty.clone(), size)),
|
||||
);
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| match v._type {
|
||||
|
@ -326,10 +315,16 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
})
|
||||
.collect()
|
||||
}
|
||||
typed_absy::ArrayExpressionInner::Value(exprs) => exprs
|
||||
.into_iter()
|
||||
.flat_map(|e| f.fold_expression(e))
|
||||
.collect(),
|
||||
typed_absy::ArrayExpressionInner::Value(exprs) => {
|
||||
let exprs: Vec<_> = exprs
|
||||
.into_iter()
|
||||
.flat_map(|e| f.fold_expression_or_spread(e))
|
||||
.collect();
|
||||
|
||||
assert_eq!(exprs.len(), size * ty.get_primitive_count());
|
||||
|
||||
exprs
|
||||
}
|
||||
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
|
||||
typed_absy::ArrayExpressionInner::IfElse(
|
||||
box condition,
|
||||
|
@ -369,40 +364,74 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.map(|member| {
|
||||
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
let size = t.get_primitive_count() * size;
|
||||
let size = ty.get_primitive_count() * size;
|
||||
|
||||
s[offset..offset + size].to_vec()
|
||||
}
|
||||
typed_absy::ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
let index = f.fold_uint_expression(index);
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let size = t.get_primitive_count() * size;
|
||||
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
let size = ty.clone().get_primitive_count() * size;
|
||||
let start = i as usize * size;
|
||||
let end = start + size;
|
||||
array[start..end].to_vec()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
typed_absy::ArrayExpressionInner::Slice(box array, box from, box to) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let from = f.fold_uint_expression(from);
|
||||
let to = f.fold_uint_expression(to);
|
||||
|
||||
match (from.into_inner(), to.into_inner()) {
|
||||
(zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => {
|
||||
assert_eq!(size, to.saturating_sub(from) as usize);
|
||||
|
||||
let element_size = ty.get_primitive_count();
|
||||
let start = from as usize * element_size;
|
||||
let end = to as usize * element_size;
|
||||
array[start..end].to_vec()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
typed_absy::ArrayExpressionInner::Repeat(box e, box count) => {
|
||||
let e = f.fold_expression(e);
|
||||
let count = f.fold_uint_expression(count);
|
||||
|
||||
match count.into_inner() {
|
||||
zir::UExpressionInner::Value(count) => {
|
||||
vec![e; count as usize].into_iter().flatten().collect()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fold_struct_expression_inner<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
t: &StructType,
|
||||
e: typed_absy::StructExpressionInner<'ast, T>,
|
||||
ty: &typed_absy::types::ConcreteStructType,
|
||||
struc: typed_absy::StructExpressionInner<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
match e {
|
||||
match struc {
|
||||
typed_absy::StructExpressionInner::Identifier(id) => {
|
||||
let variables =
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::struc(t.clone()));
|
||||
let variables = flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
&typed_absy::types::ConcreteType::struc(ty.clone()),
|
||||
);
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|v| match v._type {
|
||||
|
@ -457,13 +486,18 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.map(|member| {
|
||||
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.sum();
|
||||
|
||||
// we also need the size of this member
|
||||
let size = t
|
||||
let size = ty
|
||||
.iter()
|
||||
.find(|member| member.id == id)
|
||||
.cloned()
|
||||
.unwrap()
|
||||
.ty
|
||||
.get_primitive_count();
|
||||
|
@ -472,15 +506,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::StructExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
let index = f.fold_uint_expression(index);
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let size = t
|
||||
.iter()
|
||||
.map(|m| m.ty.get_primitive_count())
|
||||
.fold(0, |acc, current| acc + current);
|
||||
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
let size: usize = ty.iter().map(|m| m.ty.get_primitive_count()).sum();
|
||||
let start = i as usize * size;
|
||||
let end = start + size;
|
||||
array[start..end].to_vec()
|
||||
}
|
||||
|
@ -498,9 +529,12 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
|
||||
typed_absy::FieldElementExpression::Identifier(id) => {
|
||||
zir::FieldElementExpression::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::FieldElement)[0]
|
||||
.id
|
||||
.clone(),
|
||||
flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
&typed_absy::types::ConcreteType::FieldElement,
|
||||
)[0]
|
||||
.id
|
||||
.clone(),
|
||||
)
|
||||
}
|
||||
typed_absy::FieldElementExpression::Add(box e1, box e2) => {
|
||||
|
@ -525,7 +559,7 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::FieldElementExpression::Pow(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::FieldElementExpression::Pow(box e1, box e2)
|
||||
}
|
||||
typed_absy::FieldElementExpression::Neg(box e) => {
|
||||
|
@ -556,26 +590,22 @@ pub fn fold_field_expression<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.map(|member| {
|
||||
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
s[offset].clone().try_into().unwrap()
|
||||
}
|
||||
typed_absy::FieldElementExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
|
||||
let index = f.fold_field_expression(index);
|
||||
let index = f.fold_uint_expression(index);
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => array
|
||||
[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
|
||||
_ => unreachable!(""),
|
||||
}
|
||||
}
|
||||
|
@ -589,7 +619,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
match e {
|
||||
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
|
||||
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Boolean)[0]
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0]
|
||||
.id
|
||||
.clone(),
|
||||
),
|
||||
|
@ -665,25 +695,45 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
|
||||
zir::BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::Lt(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Le(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::Le(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Gt(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::Gt(box e1, box e2)
|
||||
zir::BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Ge(box e1, box e2) => {
|
||||
typed_absy::BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
zir::BooleanExpression::Ge(box e1, box e2)
|
||||
zir::BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
zir::BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
|
@ -714,25 +764,21 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.map(|member| {
|
||||
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
s[offset].clone().try_into().unwrap()
|
||||
}
|
||||
typed_absy::BooleanExpression::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
let index = f.fold_uint_expression(index);
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => array
|
||||
[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -755,9 +801,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
match e {
|
||||
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
|
||||
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
|
||||
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Uint(bitwidth))[0]
|
||||
.id
|
||||
.clone(),
|
||||
flatten_identifier_rec(
|
||||
f.fold_name(id),
|
||||
&typed_absy::types::ConcreteType::Uint(bitwidth),
|
||||
)[0]
|
||||
.id
|
||||
.clone(),
|
||||
),
|
||||
typed_absy::UExpressionInner::Add(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left);
|
||||
|
@ -771,6 +820,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
|
||||
zir::UExpressionInner::Sub(box left, box right)
|
||||
}
|
||||
typed_absy::UExpressionInner::FloorSub(..) => unreachable!(),
|
||||
typed_absy::UExpressionInner::Mult(box left, box right) => {
|
||||
let left = f.fold_uint_expression(left);
|
||||
let right = f.fold_uint_expression(right);
|
||||
|
@ -827,11 +877,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
typed_absy::UExpressionInner::Neg(box e) => {
|
||||
let bitwidth = e.bitwidth();
|
||||
|
||||
f.fold_uint_expression(typed_absy::UExpression::sub(
|
||||
typed_absy::UExpressionInner::Value(0).annotate(bitwidth),
|
||||
e,
|
||||
))
|
||||
.into_inner()
|
||||
f.fold_uint_expression(typed_absy::UExpressionInner::Value(0).annotate(bitwidth) - e)
|
||||
.into_inner()
|
||||
}
|
||||
|
||||
typed_absy::UExpressionInner::Pos(box e) => {
|
||||
|
@ -844,16 +891,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::UExpressionInner::Select(box array, box index) => {
|
||||
let array = f.fold_array_expression(array);
|
||||
let index = f.fold_field_expression(index);
|
||||
let index = f.fold_uint_expression(index);
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
match index {
|
||||
zir::FieldElementExpression::Number(i) => {
|
||||
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
match index.into_inner() {
|
||||
zir::UExpressionInner::Value(i) => {
|
||||
let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap();
|
||||
e.into_inner()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
@ -867,11 +909,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
let offset: usize = members
|
||||
.iter()
|
||||
.take_while(|member| member.id != id)
|
||||
.map(|member| member.ty.get_primitive_count())
|
||||
.map(|member| {
|
||||
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
|
||||
.unwrap()
|
||||
.get_primitive_count()
|
||||
})
|
||||
.sum();
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
|
||||
|
||||
res.into_inner()
|
||||
|
@ -893,14 +937,18 @@ pub fn fold_function<'ast, T: Field>(
|
|||
arguments: fun
|
||||
.arguments
|
||||
.into_iter()
|
||||
.flat_map(|a| f.fold_parameter(a))
|
||||
.flat_map(|a| f.fold_declaration_parameter(a))
|
||||
.collect(),
|
||||
statements: fun
|
||||
.statements
|
||||
.into_iter()
|
||||
.flat_map(|s| f.fold_statement(s))
|
||||
.collect(),
|
||||
signature: fun.signature.into(),
|
||||
signature: typed_absy::types::ConcreteSignature::try_from(
|
||||
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
.into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -908,41 +956,46 @@ pub fn fold_array_expression<'ast, T: Field>(
|
|||
f: &mut Flattener<T>,
|
||||
e: typed_absy::ArrayExpression<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner())
|
||||
let size = match e.size().into_inner() {
|
||||
typed_absy::UExpressionInner::Value(v) => v,
|
||||
_ => unreachable!(),
|
||||
} as usize;
|
||||
f.fold_array_expression_inner(
|
||||
&typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(),
|
||||
size,
|
||||
e.into_inner(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn fold_struct_expression<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
e: typed_absy::StructExpression<'ast, T>,
|
||||
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||
f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner())
|
||||
}
|
||||
|
||||
pub fn fold_function_symbol<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
s: typed_absy::TypedFunctionSymbol<'ast, T>,
|
||||
) -> zir::ZirFunctionSymbol<'ast, T> {
|
||||
match s {
|
||||
typed_absy::TypedFunctionSymbol::Here(fun) => {
|
||||
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
|
||||
}
|
||||
typed_absy::TypedFunctionSymbol::There(key, module) => {
|
||||
zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
|
||||
} // by default, do not fold modules recursively
|
||||
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
|
||||
}
|
||||
f.fold_struct_expression_inner(
|
||||
&typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(),
|
||||
e.into_inner(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn fold_program<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
p: typed_absy::TypedProgram<'ast, T>,
|
||||
mut p: typed_absy::TypedProgram<'ast, T>,
|
||||
) -> zir::ZirProgram<'ast, T> {
|
||||
let main_module = p.modules.remove(&p.main).unwrap();
|
||||
|
||||
let main_function = main_module
|
||||
.functions
|
||||
.into_iter()
|
||||
.find(|(key, _)| key.id == "main")
|
||||
.unwrap()
|
||||
.1;
|
||||
|
||||
let main_function = match main_function {
|
||||
typed_absy::TypedFunctionSymbol::Here(f) => f,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
zir::ZirProgram {
|
||||
modules: p
|
||||
.modules
|
||||
.into_iter()
|
||||
.map(|(module_id, module)| (module_id, f.fold_module(module)))
|
||||
.collect(),
|
||||
main: p.main,
|
||||
main: f.fold_function(main_function),
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -4,69 +4,93 @@
|
|||
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
|
||||
//! @date 2018
|
||||
|
||||
mod bounds_checker;
|
||||
mod flat_propagation;
|
||||
mod flatten_complex_types;
|
||||
mod inline;
|
||||
mod propagate_unroll;
|
||||
mod propagation;
|
||||
mod redefinition;
|
||||
mod return_binder;
|
||||
mod reducer;
|
||||
mod uint_optimizer;
|
||||
mod unconstrained_vars;
|
||||
mod unroll;
|
||||
mod variable_read_remover;
|
||||
mod variable_write_remover;
|
||||
|
||||
use self::bounds_checker::BoundsChecker;
|
||||
use self::flatten_complex_types::Flattener;
|
||||
use self::inline::Inliner;
|
||||
use self::propagate_unroll::PropagatedUnroller;
|
||||
use self::propagation::Propagator;
|
||||
use self::redefinition::RedefinitionOptimizer;
|
||||
use self::return_binder::ReturnBinder;
|
||||
use self::reducer::reduce_program;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use self::variable_read_remover::VariableReadRemover;
|
||||
use self::variable_write_remover::VariableWriteRemover;
|
||||
use crate::flat_absy::FlatProg;
|
||||
use crate::ir::Prog;
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use crate::typed_absy::{abi::Abi, TypedProgram};
|
||||
use crate::zir::ZirProgram;
|
||||
use std::fmt;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub trait Analyse {
|
||||
fn analyse(self) -> Self;
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Reducer(self::reducer::Error),
|
||||
OutOfBounds(self::bounds_checker::Error),
|
||||
Propagation(self::propagation::Error),
|
||||
}
|
||||
|
||||
impl From<self::reducer::Error> for Error {
|
||||
fn from(e: self::reducer::Error) -> Self {
|
||||
Error::Reducer(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<self::bounds_checker::Error> for Error {
|
||||
fn from(e: bounds_checker::Error) -> Self {
|
||||
Error::OutOfBounds(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<self::propagation::Error> for Error {
|
||||
fn from(e: propagation::Error) -> Self {
|
||||
Error::Propagation(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Error::Reducer(e) => write!(f, "{}", e),
|
||||
Error::OutOfBounds(e) => write!(f, "{}", e),
|
||||
Error::Propagation(e) => write!(f, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> TypedProgram<'ast, T> {
|
||||
pub fn analyse(self) -> ZirProgram<'ast, T> {
|
||||
// propagated unrolling
|
||||
let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!("{}", e));
|
||||
pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
|
||||
let r = reduce_program(self).map_err(Error::from)?;
|
||||
|
||||
// return binding
|
||||
let r = ReturnBinder::bind(r);
|
||||
|
||||
// inline
|
||||
let r = Inliner::inline(r);
|
||||
let abi = r.abi();
|
||||
|
||||
// propagate
|
||||
let r = Propagator::propagate(r);
|
||||
|
||||
let r = Propagator::propagate(r).map_err(Error::from)?;
|
||||
// optimize redefinitions
|
||||
let r = RedefinitionOptimizer::optimize(r);
|
||||
|
||||
// remove assignment to variable index
|
||||
let r = VariableWriteRemover::apply(r);
|
||||
|
||||
// remove variable access to complex types
|
||||
let r = VariableReadRemover::apply(r);
|
||||
|
||||
// check array accesses are in bounds
|
||||
let r = BoundsChecker::check(r).map_err(Error::from)?;
|
||||
// convert to zir, removing complex types
|
||||
let zir = Flattener::flatten(r);
|
||||
|
||||
// optimize uint expressions
|
||||
let zir = UintOptimizer::optimize(zir);
|
||||
|
||||
zir
|
||||
Ok((zir, abi))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,7 +102,6 @@ impl<T: Field> Analyse for FlatProg<T> {
|
|||
|
||||
impl<T: Field> Analyse for Prog<T> {
|
||||
fn analyse(self) -> Self {
|
||||
let r = UnconstrainedVariableDetector::detect(self);
|
||||
r
|
||||
UnconstrainedVariableDetector::detect(self)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,236 +0,0 @@
|
|||
//! Module containing iterative unrolling in order to unroll nested loops with variable bounds
|
||||
//!
|
||||
//! For example:
|
||||
//! ```zokrates
|
||||
//! for field i in 0..5 do
|
||||
//! for field j in i..5 do
|
||||
//! //
|
||||
//! endfor
|
||||
//! endfor
|
||||
//! ```
|
||||
//!
|
||||
//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop
|
||||
//!
|
||||
//! This module does exactly that:
|
||||
//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant
|
||||
//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet*
|
||||
//! - unroll again, this time the 5 inner loops all have constant bounds
|
||||
//!
|
||||
//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does
|
||||
//! not make progress anymore.
|
||||
|
||||
use crate::static_analysis::propagation::Propagator;
|
||||
use crate::static_analysis::unroll::{Output, Unroller};
|
||||
use crate::typed_absy::TypedProgram;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct PropagatedUnroller;
|
||||
|
||||
impl PropagatedUnroller {
|
||||
pub fn unroll<'ast, T: Field>(
|
||||
p: TypedProgram<'ast, T>,
|
||||
) -> Result<TypedProgram<'ast, T>, &'static str> {
|
||||
let mut blocked_at = None;
|
||||
|
||||
// unroll a first time, retrieving whether the unroll is complete
|
||||
let mut unrolled = Unroller::unroll(p);
|
||||
|
||||
loop {
|
||||
// conditions to exit the loop
|
||||
unrolled = match unrolled {
|
||||
Output::Complete(p) => return Ok(p),
|
||||
Output::Incomplete(next, index) => {
|
||||
if Some(index) == blocked_at {
|
||||
return Err("Loop unrolling failed. This happened because a loop bound is not constant");
|
||||
} else {
|
||||
// update the index where we blocked
|
||||
blocked_at = Some(index);
|
||||
|
||||
// propagate
|
||||
let propagated = Propagator::propagate_verbose(next);
|
||||
|
||||
// unroll
|
||||
Unroller::unroll(propagated)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::typed_absy::types::{FunctionKey, Signature};
|
||||
use crate::typed_absy::*;
|
||||
use zokrates_field::Bn128Field;
|
||||
|
||||
#[test]
|
||||
fn detect_non_constant_bound() {
|
||||
let loops = vec![TypedStatement::For(
|
||||
Variable::field_element("i"),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![],
|
||||
)];
|
||||
|
||||
let statements = loops;
|
||||
|
||||
let p = TypedProgram {
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
FunctionKey::with_id("main"),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: Signature::new(),
|
||||
statements,
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
main: "main".into(),
|
||||
};
|
||||
|
||||
assert!(PropagatedUnroller::unroll(p).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn for_loop() {
|
||||
// for field i in 0..2
|
||||
// for field j in i..2
|
||||
// field foo = i + j
|
||||
|
||||
// should be unrolled to
|
||||
// i_0 = 0
|
||||
// j_0 = 0
|
||||
// foo_0 = i_0 + j_0
|
||||
// j_1 = 1
|
||||
// foo_1 = i_0 + j_1
|
||||
// i_1 = 1
|
||||
// j_2 = 1
|
||||
// foo_2 = i_1 + j_1
|
||||
|
||||
let s = TypedStatement::For(
|
||||
Variable::field_element("i"),
|
||||
FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![TypedStatement::For(
|
||||
Variable::field_element("j"),
|
||||
FieldElementExpression::Identifier("i".into()),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
vec![
|
||||
TypedStatement::Declaration(Variable::field_element("foo")),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element("foo")),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier("i".into()),
|
||||
box FieldElementExpression::Identifier("j".into()),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
)],
|
||||
);
|
||||
|
||||
let expected_statements = vec![
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("i").version(0),
|
||||
)),
|
||||
FieldElementExpression::Number(Bn128Field::from(0)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(0),
|
||||
)),
|
||||
FieldElementExpression::Number(Bn128Field::from(0)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(0),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(0)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(1),
|
||||
)),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(1),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(1)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("i").version(1),
|
||||
)),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("j").version(2),
|
||||
)),
|
||||
FieldElementExpression::Number(Bn128Field::from(1)).into(),
|
||||
),
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::field_element(
|
||||
Identifier::from("foo").version(2),
|
||||
)),
|
||||
FieldElementExpression::Add(
|
||||
box FieldElementExpression::Identifier(Identifier::from("i").version(1)),
|
||||
box FieldElementExpression::Identifier(Identifier::from("j").version(2)),
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
];
|
||||
|
||||
let p = TypedProgram {
|
||||
modules: vec![(
|
||||
"main".into(),
|
||||
TypedModule {
|
||||
functions: vec![(
|
||||
FunctionKey::with_id("main"),
|
||||
TypedFunctionSymbol::Here(TypedFunction {
|
||||
arguments: vec![],
|
||||
signature: Signature::new(),
|
||||
statements: vec![s],
|
||||
}),
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
main: "main".into(),
|
||||
};
|
||||
|
||||
let statements = match PropagatedUnroller::unroll(p).unwrap().modules
|
||||
[std::path::Path::new("main")]
|
||||
.functions[&FunctionKey::with_id("main")]
|
||||
.clone()
|
||||
{
|
||||
TypedFunctionSymbol::Here(f) => f.statements,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
assert_eq!(statements, expected_statements);
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -68,6 +68,6 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast> {
|
|||
}
|
||||
|
||||
fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> {
|
||||
self.identifiers.get(&s).map(|r| r.clone()).unwrap_or(s)
|
||||
self.identifiers.get(&s).cloned().unwrap_or(s)
|
||||
}
|
||||
}
|
||||
|
|
235
zokrates_core/src/static_analysis/reducer/inline.rs
Normal file
235
zokrates_core/src/static_analysis/reducer/inline.rs
Normal file
|
@ -0,0 +1,235 @@
|
|||
// The inlining phase takes a call site (fun::<gen>(args)) and inlines it:
|
||||
|
||||
// Given:
|
||||
// ```
|
||||
// def foo<n>(field a) -> field:
|
||||
// a = a
|
||||
// n = n
|
||||
// return a
|
||||
// ```
|
||||
//
|
||||
// The call site
|
||||
// ```
|
||||
// foo::<42>(x)
|
||||
// ```
|
||||
//
|
||||
// Becomes
|
||||
// ```
|
||||
// # Call foo::<42> with a_0 := x
|
||||
// n_0 = 42
|
||||
// a_1 = a_0
|
||||
// n_1 = n_0
|
||||
// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1
|
||||
|
||||
// Notes:
|
||||
// - The body of the function is in SSA form
|
||||
// - The return value(s) are assigned to internal variables
|
||||
|
||||
use crate::embed::FlatEmbed;
|
||||
use crate::static_analysis::reducer::Output;
|
||||
use crate::static_analysis::reducer::ShallowTransformer;
|
||||
use crate::static_analysis::reducer::Versions;
|
||||
use crate::typed_absy::types::ConcreteGenericsAssignment;
|
||||
use crate::typed_absy::CoreIdentifier;
|
||||
use crate::typed_absy::Identifier;
|
||||
use crate::typed_absy::TypedAssignee;
|
||||
use crate::typed_absy::{
|
||||
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
|
||||
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
|
||||
UExpressionInner, Variable,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub enum InlineError<'ast, T> {
|
||||
Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>),
|
||||
Flat(
|
||||
FlatEmbed,
|
||||
Vec<u32>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
),
|
||||
NonConstant(
|
||||
DeclarationFunctionKey<'ast>,
|
||||
Vec<Option<UExpression<'ast, T>>>,
|
||||
Vec<TypedExpression<'ast, T>>,
|
||||
Vec<Type<'ast, T>>,
|
||||
),
|
||||
}
|
||||
|
||||
fn get_canonical_function<'ast, T: Field>(
|
||||
function_key: DeclarationFunctionKey<'ast>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) {
|
||||
match program
|
||||
.modules
|
||||
.get(&function_key.module)
|
||||
.unwrap()
|
||||
.functions
|
||||
.iter()
|
||||
.find(|(key, _)| function_key == **key)
|
||||
.unwrap()
|
||||
{
|
||||
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program),
|
||||
(key, s) => (key.clone(), s.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
type InlineResult<'ast, T> = Result<
|
||||
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
|
||||
InlineError<'ast, T>,
|
||||
>;
|
||||
|
||||
pub fn inline_call<'a, 'ast, T: Field>(
|
||||
k: DeclarationFunctionKey<'ast>,
|
||||
generics: Vec<Option<UExpression<'ast, T>>>,
|
||||
arguments: Vec<TypedExpression<'ast, T>>,
|
||||
output_types: Vec<Type<'ast, T>>,
|
||||
program: &TypedProgram<'ast, T>,
|
||||
versions: &'a mut Versions<'ast>,
|
||||
) -> InlineResult<'ast, T> {
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::typed_absy::Typed;
|
||||
|
||||
// we try to get concrete values for explicit generics
|
||||
let generics_values: Vec<Option<u32>> = generics
|
||||
.iter()
|
||||
.map(|g| {
|
||||
g.as_ref()
|
||||
.map(|g| match g.as_inner() {
|
||||
UExpressionInner::Value(v) => Ok(*v as u32),
|
||||
_ => Err(()),
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
.map_err(|_| {
|
||||
InlineError::NonConstant(
|
||||
k.clone(),
|
||||
generics.clone(),
|
||||
arguments.clone(),
|
||||
output_types.clone(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// we infer a signature based on inputs and outputs
|
||||
// this is where we could handle explicit annotations
|
||||
let inferred_signature = Signature::new()
|
||||
.generics(generics.clone())
|
||||
.inputs(arguments.iter().map(|a| a.get_type()).collect())
|
||||
.outputs(output_types.clone());
|
||||
|
||||
// we try to get concrete values for the whole signature. if this fails we should propagate again
|
||||
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
return Err(InlineError::NonConstant(
|
||||
k,
|
||||
generics,
|
||||
arguments,
|
||||
output_types,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let (decl_key, symbol) = get_canonical_function(k.clone(), program);
|
||||
|
||||
// get an assignment of generics for this call site
|
||||
let assignment: ConcreteGenericsAssignment<'ast> = k
|
||||
.signature
|
||||
.specialize(generics_values, &inferred_signature)
|
||||
.map_err(|_| {
|
||||
InlineError::Generic(
|
||||
k.clone(),
|
||||
ConcreteFunctionKey {
|
||||
module: decl_key.module.clone(),
|
||||
id: decl_key.id,
|
||||
signature: inferred_signature.clone(),
|
||||
},
|
||||
)
|
||||
})?;
|
||||
|
||||
let f = match symbol {
|
||||
TypedFunctionSymbol::Here(f) => Ok(f),
|
||||
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
|
||||
e,
|
||||
e.generics(&assignment),
|
||||
arguments.clone(),
|
||||
output_types,
|
||||
)),
|
||||
_ => unreachable!(),
|
||||
}?;
|
||||
|
||||
assert_eq!(f.arguments.len(), arguments.len());
|
||||
|
||||
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
|
||||
Output::Complete(v) => (v, None),
|
||||
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
|
||||
};
|
||||
|
||||
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone());
|
||||
|
||||
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
|
||||
.arguments
|
||||
.into_iter()
|
||||
.zip(inferred_signature.inputs.clone())
|
||||
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
|
||||
.zip(arguments.clone())
|
||||
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
|
||||
.statements
|
||||
.into_iter()
|
||||
.partition(|s| !matches!(s, TypedStatement::Return(..)));
|
||||
|
||||
assert_eq!(returns.len(), 1);
|
||||
|
||||
let returns = match returns.pop().unwrap() {
|
||||
TypedStatement::Return(e) => e,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let res: Vec<ConcreteVariable<'ast>> = inferred_signature
|
||||
.outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, t)| {
|
||||
ConcreteVariable::with_id_and_type(
|
||||
Identifier::from(CoreIdentifier::Call(i)).version(
|
||||
*versions
|
||||
.entry(CoreIdentifier::Call(i).clone())
|
||||
.and_modify(|e| *e += 1) // if it was already declared, we increment
|
||||
.or_insert(0),
|
||||
),
|
||||
t.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let expressions: Vec<TypedExpression<_>> = res
|
||||
.iter()
|
||||
.map(|v| TypedExpression::from(Variable::from(v.clone())))
|
||||
.collect();
|
||||
|
||||
assert_eq!(res.len(), returns.len());
|
||||
|
||||
let output_bindings: Vec<TypedStatement<'ast, T>> = res
|
||||
.into_iter()
|
||||
.zip(returns)
|
||||
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
|
||||
.collect();
|
||||
|
||||
let pop_log = TypedStatement::PopCallLog;
|
||||
|
||||
let statements: Vec<_> = std::iter::once(call_log)
|
||||
.chain(input_bindings)
|
||||
.chain(statements)
|
||||
.chain(output_bindings)
|
||||
.chain(std::iter::once(pop_log))
|
||||
.collect();
|
||||
|
||||
Ok(incomplete_data
|
||||
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
|
||||
.unwrap_or_else(|| Output::Complete((statements, expressions))))
|
||||
}
|
1635
zokrates_core/src/static_analysis/reducer/mod.rs
Normal file
1635
zokrates_core/src/static_analysis/reducer/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
1024
zokrates_core/src/static_analysis/reducer/shallow_ssa.rs
Normal file
1024
zokrates_core/src/static_analysis/reducer/shallow_ssa.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,58 +0,0 @@
|
|||
use crate::typed_absy::folder::fold_statement;
|
||||
use crate::typed_absy::identifier::CoreIdentifier;
|
||||
use crate::typed_absy::*;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub struct ReturnBinder;
|
||||
|
||||
impl ReturnBinder {
|
||||
pub fn bind<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
ReturnBinder {}.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Return(exprs) => {
|
||||
let ret_identifiers: Vec<Identifier<'ast>> = (0..exprs.len())
|
||||
.map(|i| CoreIdentifier::Internal("RETURN", i).into())
|
||||
.collect();
|
||||
|
||||
let ret_expressions: Vec<TypedExpression<'ast, T>> = exprs
|
||||
.iter()
|
||||
.zip(ret_identifiers.iter())
|
||||
.map(|(e, i)| match e.get_type() {
|
||||
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
|
||||
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
|
||||
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
|
||||
.annotate(*array_type.ty, array_type.size)
|
||||
.into(),
|
||||
Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone())
|
||||
.annotate(struct_type)
|
||||
.into(),
|
||||
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
|
||||
.annotate(bitwidth)
|
||||
.into(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
exprs
|
||||
.into_iter()
|
||||
.zip(ret_identifiers.iter())
|
||||
.map(|(e, i)| {
|
||||
TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(Variable::with_id_and_type(
|
||||
i.clone(),
|
||||
e.get_type(),
|
||||
)),
|
||||
e,
|
||||
)
|
||||
})
|
||||
.chain(std::iter::once(TypedStatement::Return(ret_expressions)))
|
||||
.collect()
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue