1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

merge dev, implement inference for neg and pos, fix conclicts

This commit is contained in:
schaeff 2021-04-01 17:34:05 +02:00
commit 9d13b4129d
203 changed files with 12873 additions and 7237 deletions

View file

@ -4,6 +4,7 @@ jobs:
build: build:
docker: docker:
- image: zokrates/env:latest - image: zokrates/env:latest
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
@ -28,6 +29,7 @@ jobs:
test: test:
docker: docker:
- image: zokrates/env:latest - image: zokrates/env:latest
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:
@ -42,6 +44,9 @@ jobs:
- run: - run:
name: Check format name: Check format
command: cargo fmt --all -- --check command: cargo fmt --all -- --check
- run:
name: Run clippy
command: cargo clippy
- run: - run:
name: Build name: Build
command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh
@ -80,6 +85,7 @@ jobs:
docker: docker:
- image: zokrates/env:latest - image: zokrates/env:latest
- image: trufflesuite/ganache-cli:next - image: trufflesuite/ganache-cli:next
resource_class: large
steps: steps:
- checkout - checkout
- run: - run:

74
Cargo.lock generated
View file

@ -133,7 +133,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c" checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c"
dependencies = [ dependencies = [
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -145,7 +145,7 @@ dependencies = [
"num-bigint 0.4.0", "num-bigint 0.4.0",
"num-traits 0.2.14", "num-traits 0.2.14",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -205,7 +205,7 @@ checksum = "5ac3d78c750b01f5df5b2e76d106ed31487a93b3868f14a7f0eb3a74f45e1d8a"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -653,7 +653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d" checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d"
dependencies = [ dependencies = [
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -664,7 +664,7 @@ checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -765,7 +765,7 @@ checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
"synstructure", "synstructure",
] ]
@ -809,7 +809,7 @@ dependencies = [
"num-traits 0.2.14", "num-traits 0.2.14",
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -1059,9 +1059,9 @@ dependencies = [
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.49" version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc15e39392125075f60c95ba416f5381ff6c3a948ff02ab12464715adf56c821" checksum = "2d99f9e3e84b8f67f846ef5b4cbbc3b1c29f6c759fcbce6f01aa0e73d932a24c"
dependencies = [ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
@ -1074,9 +1074,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.91" version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7" checksum = "56d855069fafbb9b344c0f962150cd2c1187975cb1c22c1522c240d8c4986714"
[[package]] [[package]]
name = "libgit2-sys" name = "libgit2-sys"
@ -1370,7 +1370,7 @@ dependencies = [
"pest_meta", "pest_meta",
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -1773,7 +1773,7 @@ checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
] ]
[[package]] [[package]]
@ -1866,9 +1866,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.64" version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f" checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
@ -1883,7 +1883,7 @@ checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
"unicode-xid 0.2.1", "unicode-xid 0.2.1",
] ]
@ -2106,9 +2106,9 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]] [[package]]
name = "wasm-bindgen" name = "wasm-bindgen"
version = "0.2.72" version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fe8f61dba8e5d645a4d8132dc7a0a66861ed5e1045d2c0ed940fab33bac0fbe" checksum = "83240549659d187488f91f33c0f8547cbfef0b2088bc470c116d1d260ef623d9"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"wasm-bindgen-macro", "wasm-bindgen-macro",
@ -2116,24 +2116,24 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-backend" name = "wasm-bindgen-backend"
version = "0.2.72" version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "046ceba58ff062da072c7cb4ba5b22a37f00a302483f7e2a6cdc18fedbdc1fd3" checksum = "ae70622411ca953215ca6d06d3ebeb1e915f0f6613e3b495122878d7ebec7dae"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"lazy_static", "lazy_static",
"log", "log",
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
[[package]] [[package]]
name = "wasm-bindgen-futures" name = "wasm-bindgen-futures"
version = "0.4.22" version = "0.4.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73157efb9af26fb564bb59a009afd1c7c334a44db171d280690d0c3faaec3468" checksum = "81b8b767af23de6ac18bf2168b690bed2902743ddf0fb39252e36f9e2bfc63ea"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"js-sys", "js-sys",
@ -2143,9 +2143,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro" name = "wasm-bindgen-macro"
version = "0.2.72" version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ef9aa01d36cda046f797c57959ff5f3c615c9cc63997a8d545831ec7976819b" checksum = "3e734d91443f177bfdb41969de821e15c516931c3c3db3d318fa1b68975d0f6f"
dependencies = [ dependencies = [
"quote 1.0.9", "quote 1.0.9",
"wasm-bindgen-macro-support", "wasm-bindgen-macro-support",
@ -2153,28 +2153,28 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro-support" name = "wasm-bindgen-macro-support"
version = "0.2.72" version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96eb45c1b2ee33545a813a92dbb53856418bf7eb54ab34f7f7ff1448a5b3735d" checksum = "d53739ff08c8a68b0fdbcd54c372b8ab800b1449ab3c9d706503bc7dd1621b2c"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
[[package]] [[package]]
name = "wasm-bindgen-shared" name = "wasm-bindgen-shared"
version = "0.2.72" version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7148f4696fb4960a346eaa60bbfb42a1ac4ebba21f750f75fc1375b098d5ffa" checksum = "d9a543ae66aa233d14bb765ed9af4a33e81b8b58d1584cf1b47ff8cd0b9e4489"
[[package]] [[package]]
name = "wasm-bindgen-test" name = "wasm-bindgen-test"
version = "0.3.22" version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f002ea97b5abdb19aafd48cbb5a0a7f6931cf36ea05a0a46ccc95d9f4c2cf43" checksum = "e972e914de63aa53bd84865e54f5c761bd274d48e5be3a6329a662c0386aa67a"
dependencies = [ dependencies = [
"console_error_panic_hook", "console_error_panic_hook",
"js-sys", "js-sys",
@ -2186,9 +2186,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-test-macro" name = "wasm-bindgen-test-macro"
version = "0.3.22" version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10a6c0bd3933daf64c78fc25a7452530f79fa7e21f77fa03d608d1e988a66735" checksum = "ea6153a8f9bf24588e9f25c87223414fff124049f68d3a442a0f0eab4768a8b6"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
@ -2196,9 +2196,9 @@ dependencies = [
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.49" version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fe19d70f5dacc03f6e46777213facae5ac3801575d56ca6cbd4c93dcd12310" checksum = "a905d57e488fec8861446d3393670fb50d27a262344013181c2cdf9fff5481be"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"wasm-bindgen", "wasm-bindgen",
@ -2252,7 +2252,7 @@ checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16"
dependencies = [ dependencies = [
"proc-macro2 1.0.24", "proc-macro2 1.0.24",
"quote 1.0.9", "quote 1.0.9",
"syn 1.0.64", "syn 1.0.67",
"synstructure", "synstructure",
] ]

7
asserts.zok Normal file
View file

@ -0,0 +1,7 @@
def id<N>() -> u32:
return N
def main():
assert(id::<5>() == 5)
assert(id::<6>() == 6)
return

View file

@ -0,0 +1 @@
Introduce constant generics for `u32` values. Introduce literal inference

View file

@ -0,0 +1 @@
Make embed functions generic, enabling unpacking to any width at minimal cost

11
example.zok Normal file
View file

@ -0,0 +1,11 @@
def foo<N>(field[N] x) -> field[N]:
return x
def bar<N>(field[N] x) -> field[N]:
field[N] r = x
return r
def main(field[3] x) -> field[2]:
field[2] z = foo(x)[0..2]
return bar(z)

13
scripts/benchmark.sh Executable file
View file

@ -0,0 +1,13 @@
#!/bin/bash
# Usage: benchmark.sh <command>
# For MacOS: install gtime with homebrew `brew install gnu-time`
cmd=$*
format="mem=%KK rss=%MK elapsed=%E cpu=%P cpu.sys=%S inputs=%I outputs=%O"
if command -v gtime; then
gtime -f "$format" $cmd
else
/usr/bin/time -f "$format" $cmd
fi

View file

@ -17,7 +17,7 @@ impl<T: From<usize>> Encode<T> for Inputs<T> {
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt; use std::fmt;
use zokrates_core::typed_absy::{Type, UBitwidth}; use zokrates_core::typed_absy::types::{ConcreteType, UBitwidth};
use zokrates_field::Field; use zokrates_field::Field;
@ -94,18 +94,20 @@ impl<T: Field> fmt::Display for Value<T> {
} }
impl<T: Field> Value<T> { impl<T: Field> Value<T> {
fn check(self, ty: Type) -> Result<CheckedValue<T>, String> { fn check(self, ty: ConcreteType) -> Result<CheckedValue<T>, String> {
match (self, ty) { match (self, ty) {
(Value::Field(f), Type::FieldElement) => Ok(CheckedValue::Field(f)), (Value::Field(f), ConcreteType::FieldElement) => Ok(CheckedValue::Field(f)),
(Value::U8(f), Type::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)), (Value::U8(f), ConcreteType::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)),
(Value::U16(f), Type::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)), (Value::U16(f), ConcreteType::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)),
(Value::U32(f), Type::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)), (Value::U32(f), ConcreteType::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)),
(Value::Boolean(b), Type::Boolean) => Ok(CheckedValue::Boolean(b)), (Value::Boolean(b), ConcreteType::Boolean) => Ok(CheckedValue::Boolean(b)),
(Value::Array(a), Type::Array(array_type)) => { (Value::Array(a), ConcreteType::Array(array_type)) => {
if a.len() != array_type.size { let size = array_type.size;
if a.len() != size as usize {
Err(format!( Err(format!(
"Expected array of size {}, found array of size {}", "Expected array of size {}, found array of size {}",
array_type.size, size,
a.len() a.len()
)) ))
} else { } else {
@ -116,15 +118,16 @@ impl<T: Field> Value<T> {
Ok(CheckedValue::Array(a)) Ok(CheckedValue::Array(a))
} }
} }
(Value::Struct(mut s), Type::Struct(members)) => { (Value::Struct(mut s), ConcreteType::Struct(struc)) => {
if s.len() != members.len() { if s.len() != struc.members_count() {
Err(format!( Err(format!(
"Expected {} member(s), found {}", "Expected {} member(s), found {}",
members.len(), struc.members_count(),
s.len() s.len()
)) ))
} else { } else {
let s = members let s = struc
.members
.into_iter() .into_iter()
.map(|member| { .map(|member| {
s.remove(&member.id) s.remove(&member.id)
@ -167,7 +170,7 @@ impl<T: From<usize>> Encode<T> for CheckedValue<T> {
} }
impl<T: Field> Decode<T> for CheckedValues<T> { impl<T: Field> Decode<T> for CheckedValues<T> {
type Expected = Vec<Type>; type Expected = Vec<ConcreteType>;
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self { fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
CheckedValues( CheckedValues(
@ -185,23 +188,24 @@ impl<T: Field> Decode<T> for CheckedValues<T> {
} }
impl<T: Field> Decode<T> for CheckedValue<T> { impl<T: Field> Decode<T> for CheckedValue<T> {
type Expected = Type; type Expected = ConcreteType;
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self { fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
let mut raw = raw; let mut raw = raw;
match expected { match expected {
Type::FieldElement => CheckedValue::Field(raw.pop().unwrap()), ConcreteType::Int => unreachable!(),
Type::Uint(UBitwidth::B8) => CheckedValue::U8( ConcreteType::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
ConcreteType::Uint(UBitwidth::B8) => CheckedValue::U8(
u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
), ),
Type::Uint(UBitwidth::B16) => CheckedValue::U16( ConcreteType::Uint(UBitwidth::B16) => CheckedValue::U16(
u16::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), u16::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
), ),
Type::Uint(UBitwidth::B32) => CheckedValue::U32( ConcreteType::Uint(UBitwidth::B32) => CheckedValue::U32(
u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
), ),
Type::Boolean => { ConcreteType::Boolean => {
let v = raw.pop().unwrap(); let v = raw.pop().unwrap();
CheckedValue::Boolean(if v == 0.into() { CheckedValue::Boolean(if v == 0.into() {
false false
@ -211,12 +215,12 @@ impl<T: Field> Decode<T> for CheckedValue<T> {
unreachable!() unreachable!()
}) })
} }
Type::Array(array_type) => CheckedValue::Array( ConcreteType::Array(array_type) => CheckedValue::Array(
raw.chunks(array_type.ty.get_primitive_count()) raw.chunks(array_type.ty.get_primitive_count())
.map(|c| CheckedValue::decode(c.to_vec(), *array_type.ty.clone())) .map(|c| CheckedValue::decode(c.to_vec(), *array_type.ty.clone()))
.collect(), .collect(),
), ),
Type::Struct(members) => CheckedValue::Struct( ConcreteType::Struct(members) => CheckedValue::Struct(
members members
.into_iter() .into_iter()
.scan(0, |state, member| { .scan(0, |state, member| {
@ -247,9 +251,9 @@ impl<T: Field> TryFrom<serde_json::Value> for Values<T> {
match v { match v {
serde_json::Value::Array(a) => a serde_json::Value::Array(a) => a
.into_iter() .into_iter()
.map(|v| Value::try_from(v)) .map(Value::try_from)
.collect::<Result<_, _>>() .collect::<Result<_, _>>()
.map(|v| Values(v)), .map(Values),
v => Err(format!("Expected an array of values, found `{}`", v)), v => Err(format!("Expected an array of values, found `{}`", v)),
} }
} }
@ -259,20 +263,22 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
type Error = String; type Error = String;
fn try_from(v: serde_json::Value) -> Result<Value<T>, Self::Error> { fn try_from(v: serde_json::Value) -> Result<Value<T>, Self::Error> {
match v { match v {
serde_json::Value::String(s) => T::try_from_dec_str(&s) serde_json::Value::String(s) => {
.map(|v| Value::Field(v)) T::try_from_dec_str(&s)
.or_else(|_| match s.len() { .map(Value::Field)
4 => u8::from_str_radix(&s[2..], 16) .or_else(|_| match s.len() {
.map(|v| Value::U8(v)) 4 => u8::from_str_radix(&s[2..], 16)
.map_err(|_| format!("Expected u8 value, found {}", s)), .map(Value::U8)
6 => u16::from_str_radix(&s[2..], 16) .map_err(|_| format!("Expected u8 value, found {}", s)),
.map(|v| Value::U16(v)) 6 => u16::from_str_radix(&s[2..], 16)
.map_err(|_| format!("Expected u16 value, found {}", s)), .map(Value::U16)
10 => u32::from_str_radix(&s[2..], 16) .map_err(|_| format!("Expected u16 value, found {}", s)),
.map(|v| Value::U32(v)) 10 => u32::from_str_radix(&s[2..], 16)
.map_err(|_| format!("Expected u32 value, found {}", s)), .map(Value::U32)
_ => Err(format!("Cannot parse {} to any type", s)), .map_err(|_| format!("Expected u32 value, found {}", s)),
}), _ => Err(format!("Cannot parse {} to any type", s)),
})
}
serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), serde_json::Value::Bool(b) => Ok(Value::Boolean(b)),
serde_json::Value::Number(n) => Err(format!( serde_json::Value::Number(n) => Err(format!(
"Value `{}` isn't allowed, did you mean `\"{}\"`?", "Value `{}` isn't allowed, did you mean `\"{}\"`?",
@ -280,14 +286,14 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
)), )),
serde_json::Value::Array(a) => a serde_json::Value::Array(a) => a
.into_iter() .into_iter()
.map(|v| Value::try_from(v)) .map(Value::try_from)
.collect::<Result<_, _>>() .collect::<Result<_, _>>()
.map(|v| Value::Array(v)), .map(Value::Array),
serde_json::Value::Object(o) => o serde_json::Value::Object(o) => o
.into_iter() .into_iter()
.map(|(k, v)| Value::try_from(v).map(|v| (k, v))) .map(|(k, v)| Value::try_from(v).map(|v| (k, v)))
.collect::<Result<Map<_, _>, _>>() .collect::<Result<Map<_, _>, _>>()
.map(|v| Value::Struct(v)), .map(Value::Struct),
v => Err(format!("Value `{}` isn't allowed", v)), v => Err(format!("Value `{}` isn't allowed", v)),
} }
} }
@ -320,10 +326,13 @@ impl<T: Field> Into<serde_json::Value> for CheckedValues<T> {
fn parse<T: Field>(s: &str) -> Result<Values<T>, Error> { fn parse<T: Field>(s: &str) -> Result<Values<T>, Error> {
let json_values: serde_json::Value = let json_values: serde_json::Value =
serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?; serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?;
Values::try_from(json_values).map_err(|e| Error::Conversion(e)) Values::try_from(json_values).map_err(Error::Conversion)
} }
pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues<T>, Error> { pub fn parse_strict<T: Field>(
s: &str,
types: Vec<ConcreteType>,
) -> Result<CheckedValues<T>, Error> {
let parsed = parse(s)?; let parsed = parse(s)?;
if parsed.0.len() != types.len() { if parsed.0.len() != types.len() {
return Err(Error::Type(format!( return Err(Error::Type(format!(
@ -338,7 +347,7 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues
.zip(types.into_iter()) .zip(types.into_iter())
.map(|(v, ty)| v.check(ty)) .map(|(v, ty)| v.check(ty))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.map_err(|e| Error::Type(e))?; .map_err(Error::Type)?;
Ok(CheckedValues(checked)) Ok(CheckedValues(checked))
} }
@ -403,14 +412,19 @@ mod tests {
mod strict { mod strict {
use super::*; use super::*;
use zokrates_core::typed_absy::types::{StructMember, StructType}; use zokrates_core::typed_absy::types::{
ConcreteStructMember, ConcreteStructType, ConcreteType,
};
#[test] #[test]
fn fields() { fn fields() {
let s = r#"["1", "2"]"#; let s = r#"["1", "2"]"#;
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::FieldElement, Type::FieldElement]) parse_strict::<Bn128Field>(
.unwrap(), s,
vec![ConcreteType::FieldElement, ConcreteType::FieldElement]
)
.unwrap(),
CheckedValues(vec![ CheckedValues(vec![
CheckedValue::Field(1.into()), CheckedValue::Field(1.into()),
CheckedValue::Field(2.into()) CheckedValue::Field(2.into())
@ -422,7 +436,8 @@ mod tests {
fn bools() { fn bools() {
let s = "[true, false]"; let s = "[true, false]";
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::Boolean, Type::Boolean]).unwrap(), parse_strict::<Bn128Field>(s, vec![ConcreteType::Boolean, ConcreteType::Boolean])
.unwrap(),
CheckedValues(vec![ CheckedValues(vec![
CheckedValue::Boolean(true), CheckedValue::Boolean(true),
CheckedValue::Boolean(false) CheckedValue::Boolean(false)
@ -434,7 +449,11 @@ mod tests {
fn array() { fn array() {
let s = "[[true, false]]"; let s = "[[true, false]]";
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::array(Type::Boolean, 2)]).unwrap(), parse_strict::<Bn128Field>(
s,
vec![ConcreteType::array((ConcreteType::Boolean, 2usize))]
)
.unwrap(),
CheckedValues(vec![CheckedValue::Array(vec![ CheckedValues(vec![CheckedValue::Array(vec![
CheckedValue::Boolean(true), CheckedValue::Boolean(true),
CheckedValue::Boolean(false) CheckedValue::Boolean(false)
@ -448,10 +467,13 @@ mod tests {
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>( parse_strict::<Bn128Field>(
s, s,
vec![Type::Struct(StructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![StructMember::new("a".into(), Type::FieldElement)] vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))] ))]
) )
.unwrap(), .unwrap(),
@ -466,10 +488,13 @@ mod tests {
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>( parse_strict::<Bn128Field>(
s, s,
vec![Type::Struct(StructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![StructMember::new("a".into(), Type::FieldElement)] vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))] ))]
) )
.unwrap_err(), .unwrap_err(),
@ -480,10 +505,13 @@ mod tests {
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>( parse_strict::<Bn128Field>(
s, s,
vec![Type::Struct(StructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![StructMember::new("a".into(), Type::FieldElement)] vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))] ))]
) )
.unwrap_err(), .unwrap_err(),
@ -494,10 +522,13 @@ mod tests {
assert_eq!( assert_eq!(
parse_strict::<Bn128Field>( parse_strict::<Bn128Field>(
s, s,
vec![Type::Struct(StructType::new( vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(), "".into(),
"".into(), "".into(),
vec![StructMember::new("a".into(), Type::FieldElement)] vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))] ))]
) )
.unwrap_err(), .unwrap_err(),

View file

@ -12,6 +12,7 @@
- [Control flow](language/control_flow.md) - [Control flow](language/control_flow.md)
- [Imports](language/imports.md) - [Imports](language/imports.md)
- [Comments](language/comments.md) - [Comments](language/comments.md)
- [Generics](language/generics.md)
- [Macros](language/macros.md) - [Macros](language/macros.md)
- [Toolbox](toolbox/index.md) - [Toolbox](toolbox/index.md)

View file

@ -12,6 +12,12 @@ Arguments are passed by value.
{{#include ../../../zokrates_cli/examples/book/side_effects.zok}} {{#include ../../../zokrates_cli/examples/book/side_effects.zok}}
``` ```
Generic paramaters, if any, must be compile-time constants. They are inferred by the compiler if that is possible, but can also be provided explicitly.
```zokrates
{{#include ../../../zokrates_cli/examples/book/generic_call.zok}}
```
### If-expressions ### If-expressions
An if-expression allows you to branch your code depending on a boolean condition. An if-expression allows you to branch your code depending on a boolean condition.
@ -28,7 +34,7 @@ For loops are available with the following syntax:
{{#include ../../../zokrates_cli/examples/book/for.zok}} {{#include ../../../zokrates_cli/examples/book/for.zok}}
``` ```
The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. They can depend on generic parameters.
### Assertions ### Assertions

View file

@ -7,7 +7,14 @@ A function has to be declared at the top level before it is called.
``` ```
A function's signature has to be explicitly provided. A function's signature has to be explicitly provided.
Functions can return many values by providing them as a comma-separated list.
A function can be generic over any number of values of type `u32`.
```zokrates
{{#include ../../../zokrates_cli/examples/book/generic_function_declaration.zok}}
```
Functions can return multiple values by providing them as a comma-separated list.
```zokrates ```zokrates
{{#include ../../../zokrates_cli/examples/book/multi_return.zok}} {{#include ../../../zokrates_cli/examples/book/multi_return.zok}}

View file

@ -0,0 +1,7 @@
## Generics
ZoKrates supports code that is generic over constants of the `u32` type. No specific keyword is used: the compiler determines if the generic parameters are indeed constant at compile time. Here's an example of generic code in ZoKrates:
```zokrates
{{#include ../../../zokrates_cli/examples/book/generics.zok}}
```

View file

@ -8,7 +8,7 @@ ZoKrates currently exposes two primitive types and two complex types:
This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers. This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers.
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/reference/proving_schemes.html#alt_bn128) curve supported by Ethereum. As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/toolbox/proving_schemes.html#alt_bn128) curve supported by Ethereum.
While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have: While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have:
@ -32,13 +32,23 @@ Similarly to booleans, unsigned integer inputs of the main function only accept
The division operation calculates the standard floor division for integers. The `%` operand can be used to obtain the remainder. The division operation calculates the standard floor division for integers. The `%` operand can be used to obtain the remainder.
### Numeric inference
In the case of decimal literals like `42`, the compiler tries to find the appropriate type (`field`, `u8`, `u16` or `u32`) depending on the context. If it cannot converge to a single option, an error is returned. This means that there is no default type for decimal literals.
All operations between literals have the semantics of the infered type.
```zokrates
{{#include ../../../zokrates_cli/examples/book/numeric_inference.zok}}
```
## Complex Types ## Complex Types
ZoKrates provides two complex types: arrays and structs. ZoKrates provides two complex types: arrays and structs.
### Arrays ### Arrays
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](/language/generics.html)
Arrays can contain elements of any type and have arbitrary dimensions. Arrays can contain elements of any type and have arbitrary dimensions.
The following example code shows examples of how to use arrays: The following example code shows examples of how to use arrays:

View file

@ -11,5 +11,5 @@ fn export_stdlib() {
let out_dir = env::var("OUT_DIR").unwrap(); let out_dir = env::var("OUT_DIR").unwrap();
let mut options = CopyOptions::new(); let mut options = CopyOptions::new();
options.overwrite = true; options.overwrite = true;
copy_items(&vec!["tests/contract"], out_dir, &options).unwrap(); copy_items(&["tests/contract"], out_dir, &options).unwrap();
} }

View file

@ -1,7 +1,7 @@
def main() -> field: def main() -> u32:
field[3] a = [1, 2, 3] u32[3] a = [1, 2, 3]
field c = 0 u32 c = 0
for field i in 0..3 do for u32 i in 0..3 do
c = c + a[i] c = c + a[i]
endfor endfor
return c return c

View file

@ -1,7 +1,7 @@
def main() -> (field[3]): def main() -> (u32[3]):
field[3] a = [1, 2, 3] u32[3] a = [1, 2, 3]
field[3] c = [4, 5, 6] u32[3] c = [4, 5, 6]
for field i in 0..3 do for u32 i in 0..3 do
c[i] = c[i] + a[i] c[i] = c[i] + a[i]
endfor endfor
return c return c

View file

@ -3,7 +3,7 @@ def main(bool[3] a) -> (field[3]):
a[1] = true || a[2] a[1] = true || a[2]
a[2] = a[0] a[2] = a[0]
field[3] result = [0; 3] field[3] result = [0; 3]
for field i in 0..3 do for u32 i in 0..3 do
result[i] = if a[i] then 33 else 0 fi result[i] = if a[i] then 33 else 0 fi
endfor endfor
return result return result

View file

@ -1,9 +1,9 @@
def main(field[2][2][2] cube) -> field: def main(field[2][2][2] cube) -> field:
field res = 0 field res = 0
for field i in 0..2 do for u32 i in 0..2 do
for field j in 0..2 do for u32 j in 0..2 do
for field k in 0..2 do for u32 k in 0..2 do
res = res + cube[i][j][k] res = res + cube[i][j][k]
endfor endfor
endfor endfor

View file

@ -1,2 +1,2 @@
def main(field index, field[5] array) -> field: def main(u32 index, field[5] array) -> field:
return array[index] return array[index]

View file

@ -1,4 +1,4 @@
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]): def main(field[10][10][10] a, u32 i, u32 j, u32 k) -> (field[3]):
a[i][j][k] = 42 a[i][j][k] = 42
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]] field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
return b[0] return b[0]

View file

@ -0,0 +1,4 @@
def main(field a) -> field[4]:
u32 SIZE = 4
field[SIZE] res = [a; SIZE]
return res

View file

@ -1,6 +1,6 @@
def slice32from(field offset, field[2048] input) -> (field[32]): def slice32from(u32 offset, field[2048] input) -> (field[32]):
field[32] result = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] field[32] result = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
for field i in 0..32 do for u32 i in 0..32 do
result[i] = input[offset + i] result[i] = input[offset + i]
endfor endfor
return result return result

View file

@ -1,4 +1,4 @@
def get(field[32] array, field index) -> field: def get(field[32] array, u32 index) -> field:
return array[index] return array[index]
def main() -> field: def main() -> field:

View file

@ -5,4 +5,6 @@ def main() -> field:
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4 field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
field[2] d = a[1..3] // initialize an array copying a slice from `a` field[2] d = a[1..3] // initialize an array copying a slice from `a`
bool[3] e = [true, true || false, true] // initialize a boolean array bool[3] e = [true, true || false, true] // initialize a boolean array
u32 SIZE = 3
field[SIZE] f = [1, 2, 3] // initialize a field array with a size that's a compile-time constant
return a[0] + b[1] + c[2] return a[0] + b[1] + c[2]

View file

@ -1,3 +1,3 @@
def main() -> (): def main() -> ():
assert(1 == 2) assert(1f == 2f)
return return

View file

@ -1,7 +1,7 @@
def main() -> field: def main() -> u32:
field res = 0 u32 res = 0
for field i in 0..4 do for u32 i in 0..4 do
for field j in i..5 do for u32 j in i..5 do
res = res + i res = res + i
endfor endfor
endfor endfor

View file

@ -1,6 +1,6 @@
def main() -> field: def main() -> u32:
field a = 0 u32 a = 0
for field i in 0..5 do for u32 i in 0..5 do
a = a + i a = a + i
endfor endfor
// return i <- not allowed // return i <- not allowed

View file

@ -1,5 +1,5 @@
def foo() -> field: def foo(field a, field b) -> field:
return 1 return a + b
def main() -> field: def main() -> field:
return foo() return foo(1, 2)

View file

@ -0,0 +1,7 @@
def foo<N, P>() -> field[P]:
return [42; P]
def main() -> field[2]:
// `P` is inferred from the declaration of `res`, while `N` is provided explicitly
field[2] res = foo::<3, _>()
return res

View file

@ -0,0 +1,6 @@
def foo<N>() -> field[N]:
return [42; N]
def main() -> field[2]:
field[2] res = foo()
return res

View file

@ -0,0 +1,9 @@
def sum<N>(field[N] a) -> field:
field res = 0
for u32 i in 0..N do
res = res + a[i]
endfor
return res
def main(field[3] a) -> field:
return sum(a)

View file

@ -1,7 +1,7 @@
def main() -> field: def main() -> field:
field a = 2 field a = 2
// field a = 3 <- not allowed // field a = 3 <- not allowed
for field i in 0..5 do for u32 i in 0..5 do
// field a = 7 <- not allowed // field a = 7 <- not allowed
endfor endfor
return a return a

View file

@ -0,0 +1,9 @@
def main():
// `255` is infered to `255f`, and the addition happens between field elements
assert(255 + 1f == 256)
// `255` is infered to `255u8`, and the addition happens between u8
// This causes an overflow
assert(255 + 1u8 == 0)
return

View file

@ -0,0 +1,5 @@
import "EMBED/u8_to_bits" as u8_to_bits
def main(u8 x):
bool[32] b = u8_to_bits(x) // note the incorrect array length on the left
return

View file

@ -0,0 +1,8 @@
def foo<N>(field[N] a) -> bool:
field[3] b = a
return true
def main(field[1] a):
assert(foo(a))
return

View file

@ -0,0 +1,3 @@
def main():
assert([1] == [1, 2])
return

View file

@ -0,0 +1,2 @@
def main<P>(field[P] a):
return

View file

@ -0,0 +1,5 @@
def foo<P>(field[P] a, field[P] b) -> field:
return 42
def main() -> field:
return foo([1, 2], [1])

View file

@ -0,0 +1,3 @@
def main():
assert([[1]] == [1, 2])
return

View file

@ -0,0 +1,2 @@
def main<P>():
return

View file

@ -1,4 +1,4 @@
def main() -> field: def main() -> field:
for field i in 0..5 do for u32 i in 0..5 do
endfor endfor
return i return i

View file

@ -14,6 +14,6 @@ def main(field a) -> field:
assert(2 * b == a * 12 + 60) assert(2 * b == a * 12 + 60)
field c = 7 * (b + a) field c = 7 * (b + a)
assert(isEqual(c, 7 * b + 7 * a)) assert(isEqual(c, 7 * b + 7 * a))
field k = if [1, 2] == [3, 4] then 1 else 3 fi field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}]) assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
return b + c return b + c

View file

@ -1,6 +1,10 @@
def bound(field x) -> u32:
return 41 + 1
def main(field a) -> field: def main(field a) -> field:
field x = 7 field x = 7
for field i in 0..10 do x = x + 1
for u32 i in 0..bound(x) do
// x = x + a // x = x + a
x = x + a x = x + a
endfor endfor

View file

@ -4,14 +4,14 @@ def lt(field a,field b) -> bool:
def cutoff() -> field: def cutoff() -> field:
return 31337 return 31337
def getThing(field index) -> field: def getThing(u32 index) -> field:
field[6] a = [13, 23, 43, 53, 73, 83] field[6] a = [13, 23, 43, 53, 73, 83]
return a[index] return a[index]
def cubeThing(field thing) -> field: def cubeThing(field thing) -> field:
return thing**3 return thing**3
def main(field index) -> bool: def main(u32 index) -> bool:
field thing = getThing(index) field thing = getThing(index)
thing = cubeThing(thing) thing = cubeThing(thing)
return lt(cutoff(), thing) return lt(cutoff(), thing)

View file

@ -4,9 +4,6 @@ import "ecc/babyjubjubParams" as context
from "ecc/babyjubjubParams" import BabyJubJubParams from "ecc/babyjubjubParams" import BabyJubJubParams
import "hashes/utils/256bitsDirectionHelper" as multiplex import "hashes/utils/256bitsDirectionHelper" as multiplex
def multiplex(bool selector, u32[8] left, u32[8] right) -> (u32[8]):
return if selector then right else left fi
// Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes // Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes
// directionSelector=> true if current digest is on the rhs of the hash // directionSelector=> true if current digest is on the rhs of the hash

View file

@ -1,9 +1,11 @@
import "utils/casts/u32_to_field" as to_field
// Binomial Coeffizient, n!/(k!*(n-k)!). // Binomial Coeffizient, n!/(k!*(n-k)!).
def fac(field x) -> field: def fac(field x) -> field:
field f = 1 field f = 1
field counter = 0 field counter = 0
for field i in 1..100 do for u32 i in 1..100 do
f = if counter == x then f else f * i fi f = if counter == x then f else f * to_field(i) fi
counter = if counter == x then counter else counter + 1 fi counter = if counter == x then counter else counter + 1 fi
endfor endfor
return f return f

View file

@ -2,7 +2,7 @@ def main() -> field:
field a = 1 + 2 + 3 field a = 1 + 2 + 3
field b = if 1 < a then 3 else a + 3 fi field b = if 1 < a then 3 else a + 3 fi
field c = if b + a == 2 then 1 else b fi field c = if b + a == 2 then 1 else b fi
for field e in 0..2 do for u32 e in 0..2 do
field g = 4 field g = 4
c = c + g c = c + g
endfor endfor

View file

@ -1,3 +1,3 @@
def main() -> field: def main() -> field:
field a = 2 u32 a = 2
return 2**(a**2 + 2) return 2**(a * 2 + 2)

View file

@ -32,26 +32,26 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva
bool res = true bool res = true
// go through the whole grid and check that all elements are valid // go through the whole grid and check that all elements are valid
for field i in 0..4 do for u32 i in 0..4 do
for field j in 0..4 do for u32 j in 0..4 do
res = res && validateInput(a[i][j]) res = res && validateInput(a[i][j])
endfor endfor
endfor endfor
// go through the 4 2x2 boxes and check that they do not contain duplicates // go through the 4 2x2 boxes and check that they do not contain duplicates
for field i in 0..1 do for u32 i in 0..1 do
for field j in 0..1 do for u32 j in 0..1 do
res = res && checkNoDuplicates(a[2*i][2*i], a[2*i][2*i + 1], a[2*i + 1][2*i], a[2*i + 1][2*i + 1]) res = res && checkNoDuplicates(a[2*i][2*i], a[2*i][2*i + 1], a[2*i + 1][2*i], a[2*i + 1][2*i + 1])
endfor endfor
endfor endfor
// go through the 4 rows and check that they do not contain duplicates // go through the 4 rows and check that they do not contain duplicates
for field i in 0..4 do for u32 i in 0..4 do
res = res && checkNoDuplicates(a[i][0], a[i][1], a[i][2], a[i][3]) res = res && checkNoDuplicates(a[i][0], a[i][1], a[i][2], a[i][3])
endfor endfor
// go through the 4 columns and check that they do not contain duplicates // go through the 4 columns and check that they do not contain duplicates
for field j in 0..4 do for u32 j in 0..4 do
res = res && checkNoDuplicates(a[0][j], a[1][j], a[2][j], a[3][j]) res = res && checkNoDuplicates(a[0][j], a[1][j], a[2][j], a[3][j])
endfor endfor

View file

@ -10,6 +10,6 @@ def isWaldo(field a, field p, field q) -> bool:
return a == p * q return a == p * q
// define all // define all
def main(field[3] a, private field index, private field p, private field q) -> bool: def main(field[3] a, private u32 index, private field p, private field q) -> bool:
// prover provides the index of Waldo // prover provides the index of Waldo
return isWaldo(a[index], p, q) return isWaldo(a[index], p, q)

View file

@ -32,10 +32,13 @@ fn cli() -> Result<(), String> {
compile::subcommand(), compile::subcommand(),
check::subcommand(), check::subcommand(),
compute_witness::subcommand(), compute_witness::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
setup::subcommand(), setup::subcommand(),
export_verifier::subcommand(), export_verifier::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
generate_proof::subcommand(), generate_proof::subcommand(),
print_proof::subcommand(), print_proof::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
verify::subcommand()]) verify::subcommand()])
.get_matches(); .get_matches();
@ -140,7 +143,7 @@ mod tests {
let interpreter = ir::Interpreter::default(); let interpreter = ir::Interpreter::default();
let _ = interpreter let _ = interpreter
.execute(&artifacts.prog(), &vec![Bn128Field::from(0)]) .execute(&artifacts.prog(), &[Bn128Field::from(0)])
.unwrap(); .unwrap();
} }
} }
@ -169,7 +172,7 @@ mod tests {
let interpreter = ir::Interpreter::default(); let interpreter = ir::Interpreter::default();
let res = interpreter.execute(&artifacts.prog(), &vec![Bn128Field::from(0)]); let res = interpreter.execute(&artifacts.prog(), &[Bn128Field::from(0)]);
assert!(res.is_err()); assert!(res.is_err());
} }

View file

@ -19,6 +19,7 @@ lazy_static! {
.unwrap(); .unwrap();
} }
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") { pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
if cfg!(feature = "ark") { if cfg!(feature = "ark") {
if cfg!(feature = "bellman") { if cfg!(feature = "bellman") {
@ -26,27 +27,21 @@ pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
} else { } else {
&[LIBSNARK, ARK] &[LIBSNARK, ARK]
} }
} else if cfg!(feature = "bellman") {
&[BELLMAN, LIBSNARK]
} else { } else {
if cfg!(feature = "bellman") { &[LIBSNARK]
&[BELLMAN, LIBSNARK]
} else {
&[LIBSNARK]
}
} }
} else if cfg!(feature = "ark") {
if cfg!(feature = "bellman") {
&[BELLMAN, ARK]
} else {
&[ARK]
}
} else if cfg!(feature = "bellman") {
&[BELLMAN]
} else { } else {
if cfg!(feature = "ark") { &[]
if cfg!(feature = "bellman") {
&[BELLMAN, ARK]
} else {
&[ARK]
}
} else {
if cfg!(feature = "bellman") {
&[BELLMAN]
} else {
&[]
}
}
}; };
pub const BN128: &str = "bn128"; pub const BN128: &str = "bn128";

View file

@ -69,7 +69,7 @@ fn cli_check<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
format!( format!(
"{}:{}", "{}:{}",
file.strip_prefix(std::env::current_dir().unwrap()) file.strip_prefix(std::env::current_dir().unwrap())
.unwrap_or(file.as_path()) .unwrap_or_else(|_| file.as_path())
.display(), .display(),
e.value() e.value()
) )

View file

@ -92,9 +92,9 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
let fmt_error = |e: &CompileError| { let fmt_error = |e: &CompileError| {
let file = e.file().canonicalize().unwrap(); let file = e.file().canonicalize().unwrap();
format!( format!(
"{}:{}", "{}: {}",
file.strip_prefix(std::env::current_dir().unwrap()) file.strip_prefix(std::env::current_dir().unwrap())
.unwrap_or(file.as_path()) .unwrap_or_else(|_| file.as_path())
.display(), .display(),
e.value() e.value()
) )
@ -154,7 +154,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
.map_err(|why| format!("Couldn't create {}: {}", hr_output_path.display(), why))?; .map_err(|why| format!("Couldn't create {}: {}", hr_output_path.display(), why))?;
let mut hrofb = BufWriter::new(hr_output_file); let mut hrofb = BufWriter::new(hr_output_file);
write!(&mut hrofb, "{}\n", program_flattened) writeln!(&mut hrofb, "{}", program_flattened)
.map_err(|_| "Unable to write data to file".to_string())?; .map_err(|_| "Unable to write data to file".to_string())?;
hrofb hrofb
.flush() .flush()

View file

@ -8,7 +8,7 @@ use zokrates_abi::Encode;
use zokrates_core::ir; use zokrates_core::ir;
use zokrates_core::ir::ProgEnum; use zokrates_core::ir::ProgEnum;
use zokrates_core::typed_absy::abi::Abi; use zokrates_core::typed_absy::abi::Abi;
use zokrates_core::typed_absy::{Signature, Type}; use zokrates_core::typed_absy::types::{ConcreteSignature, ConcreteType};
use zokrates_field::Field; use zokrates_field::Field;
pub fn subcommand() -> App<'static, 'static> { pub fn subcommand() -> App<'static, 'static> {
@ -106,9 +106,12 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
abi.signature() abi.signature()
} }
false => Signature::new() false => ConcreteSignature::new()
.inputs(vec![Type::FieldElement; ir_prog.main.arguments.len()]) .inputs(vec![
.outputs(vec![Type::FieldElement; ir_prog.main.returns.len()]), ConcreteType::FieldElement;
ir_prog.main.arguments.len()
])
.outputs(vec![ConcreteType::FieldElement; ir_prog.main.returns.len()]),
}; };
use zokrates_abi::Inputs; use zokrates_abi::Inputs;
@ -123,8 +126,8 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
a.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string())) a.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
}) })
.unwrap_or(Ok(vec![])) .unwrap_or_else(|| Ok(vec![]))
.map(|v| Inputs::Raw(v)) .map(Inputs::Raw)
} }
// take stdin arguments // take stdin arguments
true => { true => {
@ -137,7 +140,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
use zokrates_abi::parse_strict; use zokrates_abi::parse_strict;
parse_strict(&input, signature.inputs) parse_strict(&input, signature.inputs)
.map(|parsed| Inputs::Abi(parsed)) .map(Inputs::Abi)
.map_err(|why| why.to_string()) .map_err(|why| why.to_string())
} }
Err(_) => Err(String::from("???")), Err(_) => Err(String::from("???")),
@ -148,10 +151,10 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
Ok(_) => { Ok(_) => {
input.retain(|x| x != '\n'); input.retain(|x| x != '\n');
input input
.split(" ") .split(' ')
.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string())) .map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.map(|v| Inputs::Raw(v)) .map(Inputs::Raw)
} }
Err(_) => Err(String::from("???")), Err(_) => Err(String::from("???")),
}, },

View file

@ -174,7 +174,7 @@ fn cli_generate_proof<T: Field, S: Scheme<T>, B: Backend<T, S>>(
.write(proof.as_bytes()) .write(proof.as_bytes())
.map_err(|why| format!("Couldn't write to {}: {}", proof_path.display(), why))?; .map_err(|why| format!("Couldn't write to {}: {}", proof_path.display(), why))?;
println!("Proof:\n{}", format!("{}", proof)); println!("Proof:\n{}", proof);
Ok(()) Ok(())
} }

View file

@ -1,7 +1,6 @@
pub mod check; pub mod check;
pub mod compile; pub mod compile;
pub mod compute_witness; pub mod compute_witness;
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub mod export_verifier; pub mod export_verifier;
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] #[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub mod generate_proof; pub mod generate_proof;

View file

@ -1,9 +1,11 @@
import "utils/casts/u32_to_field" as to_field
// Binomial Coeffizient, n!/(k!*(n-k)!). // Binomial Coeffizient, n!/(k!*(n-k)!).
def fac(field x) -> field: def fac(field x) -> field:
field f = 1 field f = 1
field counter = 0 field counter = 0
for field i in 1..100 do for u32 i in 1..100 do
f = if counter == x then f else f * i fi f = if counter == x then f else f * to_field(i) fi
counter = if counter == x then counter else counter + 1 fi counter = if counter == x then counter else counter + 1 fi
endfor endfor
return f return f

View file

@ -3,7 +3,7 @@ extern crate serde_json;
#[cfg(test)] #[cfg(test)]
mod integration { mod integration {
use assert_cli;
use serde_json::from_reader; use serde_json::from_reader;
use std::fs; use std::fs;
use std::fs::File; use std::fs::File;
@ -147,11 +147,11 @@ mod integration {
.map_err(|why| why.to_string()) .map_err(|why| why.to_string())
.unwrap(); .unwrap();
let signature = abi.signature().clone(); let signature = abi.signature();
let inputs_abi: zokrates_abi::Inputs<zokrates_field::Bn128Field> = let inputs_abi: zokrates_abi::Inputs<zokrates_field::Bn128Field> =
parse_strict(&json_input_str, signature.inputs) parse_strict(&json_input_str, signature.inputs)
.map(|parsed| zokrates_abi::Inputs::Abi(parsed)) .map(zokrates_abi::Inputs::Abi)
.map_err(|why| why.to_string()) .map_err(|why| why.to_string())
.unwrap(); .unwrap();
let inputs_raw: Vec<_> = inputs_abi let inputs_raw: Vec<_> = inputs_abi
@ -169,7 +169,7 @@ mod integration {
inline_witness_path.to_str().unwrap(), inline_witness_path.to_str().unwrap(),
]; ];
if inputs_raw.len() > 0 { if !inputs_raw.is_empty() {
compute_inline.push("-a"); compute_inline.push("-a");
for arg in &inputs_raw { for arg in &inputs_raw {
@ -202,7 +202,7 @@ mod integration {
assert_eq!(inline_witness, witness); assert_eq!(inline_witness, witness);
for line in expected_witness.as_str().split("\n") { for line in expected_witness.as_str().split('\n') {
assert!( assert!(
witness.contains(line), witness.contains(line),
"Witness generation failed for {}\n\nLine \"{}\" not found in witness", "Witness generation failed for {}\n\nLine \"{}\" not found in witness",

View file

@ -1,7 +1,6 @@
use crate::absy; use crate::absy;
use crate::imports; use crate::imports;
use num::ToPrimitive;
use num_bigint::BigUint; use num_bigint::BigUint;
use zokrates_pest_ast as pest; use zokrates_pest_ast as pest;
@ -10,14 +9,14 @@ impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
absy::Module::with_symbols( absy::Module::with_symbols(
prog.structs prog.structs
.into_iter() .into_iter()
.map(|t| absy::SymbolDeclarationNode::from(t)) .map(absy::SymbolDeclarationNode::from)
.chain( .chain(
prog.functions prog.functions
.into_iter() .into_iter()
.map(|f| absy::SymbolDeclarationNode::from(f)), .map(absy::SymbolDeclarationNode::from),
), ),
) )
.imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i))) .imports(prog.imports.into_iter().map(absy::ImportNode::from))
} }
} }
@ -31,17 +30,16 @@ impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
.alias(import.alias.map(|a| a.span.as_str())) .alias(import.alias.map(|a| a.span.as_str()))
.span(import.span) .span(import.span)
} }
pest::ImportDirective::From(import) => imports::Import::new( pest::ImportDirective::From(import) => {
Some(import.symbol.span.as_str()), let symbol_str = import.symbol.span.as_str();
std::path::Path::new(import.source.span.as_str()),
) imports::Import::new(
.alias( Some(import.symbol.span.as_str()),
import std::path::Path::new(import.source.span.as_str()),
.alias )
.map(|a| a.span.as_str()) .alias(import.alias.map(|a| a.span.as_str()).or(Some(symbol_str)))
.or(Some(import.symbol.span.as_str())), .span(import.span)
) }
.span(import.span),
} }
} }
} }
@ -58,7 +56,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
fields: definition fields: definition
.fields .fields
.into_iter() .into_iter()
.map(|f| absy::StructDefinitionFieldNode::from(f)) .map(absy::StructDefinitionFieldNode::from)
.collect(), .collect(),
} }
.span(span.clone()); .span(span.clone());
@ -72,7 +70,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
} }
impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'ast> { impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'ast> {
fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode { fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode<'ast> {
use crate::absy::NodeValue; use crate::absy::NodeValue;
let span = field.span; let span = field.span;
@ -92,6 +90,13 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
let span = function.span; let span = function.span;
let signature = absy::UnresolvedSignature::new() let signature = absy::UnresolvedSignature::new()
.generics(
function
.generics
.into_iter()
.map(absy::ConstantGenericNode::from)
.collect(),
)
.inputs( .inputs(
function function
.parameters .parameters
@ -105,7 +110,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
.returns .returns
.clone() .clone()
.into_iter() .into_iter()
.map(|r| absy::UnresolvedTypeNode::from(r)) .map(absy::UnresolvedTypeNode::from)
.collect(), .collect(),
); );
@ -115,12 +120,12 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
arguments: function arguments: function
.parameters .parameters
.into_iter() .into_iter()
.map(|a| absy::ParameterNode::from(a)) .map(absy::ParameterNode::from)
.collect(), .collect(),
statements: function statements: function
.statements .statements
.into_iter() .into_iter()
.flat_map(|s| statements_from_statement(s)) .flat_map(statements_from_statement)
.collect(), .collect(),
signature, signature,
} }
@ -134,6 +139,16 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
} }
} }
impl<'ast> From<pest::IdentifierExpression<'ast>> for absy::ConstantGenericNode<'ast> {
fn from(g: pest::IdentifierExpression<'ast>) -> absy::ConstantGenericNode<'ast> {
use absy::NodeValue;
let name = g.span.as_str();
name.span(g.span)
}
}
impl<'ast> From<pest::Parameter<'ast>> for absy::ParameterNode<'ast> { impl<'ast> From<pest::Parameter<'ast>> for absy::ParameterNode<'ast> {
fn from(param: pest::Parameter<'ast>) -> absy::ParameterNode<'ast> { fn from(param: pest::Parameter<'ast>) -> absy::ParameterNode<'ast> {
use crate::absy::NodeValue; use crate::absy::NodeValue;
@ -247,7 +262,7 @@ impl<'ast> From<pest::ReturnStatement<'ast>> for absy::StatementNode<'ast> {
expressions: statement expressions: statement
.expressions .expressions
.into_iter() .into_iter()
.map(|e| absy::ExpressionNode::from(e)) .map(absy::ExpressionNode::from)
.collect(), .collect(),
} }
.span(statement.span.clone()), .span(statement.span.clone()),
@ -275,7 +290,7 @@ impl<'ast> From<pest::IterationStatement<'ast>> for absy::StatementNode<'ast> {
let statements: Vec<absy::StatementNode<'ast>> = statement let statements: Vec<absy::StatementNode<'ast>> = statement
.statements .statements
.into_iter() .into_iter()
.flat_map(|s| statements_from_statement(s)) .flat_map(statements_from_statement)
.collect(); .collect();
let var = absy::Variable::new(index, ty).span(statement.index.span); let var = absy::Variable::new(index, ty).span(statement.index.span);
@ -289,7 +304,7 @@ impl<'ast> From<pest::Expression<'ast>> for absy::ExpressionNode<'ast> {
match expression { match expression {
pest::Expression::Binary(e) => absy::ExpressionNode::from(e), pest::Expression::Binary(e) => absy::ExpressionNode::from(e),
pest::Expression::Ternary(e) => absy::ExpressionNode::from(e), pest::Expression::Ternary(e) => absy::ExpressionNode::from(e),
pest::Expression::Constant(e) => absy::ExpressionNode::from(e), pest::Expression::Literal(e) => absy::ExpressionNode::from(e),
pest::Expression::Identifier(e) => absy::ExpressionNode::from(e), pest::Expression::Identifier(e) => absy::ExpressionNode::from(e),
pest::Expression::Postfix(e) => absy::ExpressionNode::from(e), pest::Expression::Postfix(e) => absy::ExpressionNode::from(e),
pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e), pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e),
@ -458,7 +473,7 @@ impl<'ast> From<pest::InlineArrayExpression<'ast>> for absy::ExpressionNode<'ast
array array
.expressions .expressions
.into_iter() .into_iter()
.map(|e| absy::SpreadOrExpression::from(e)) .map(absy::SpreadOrExpression::from)
.collect(), .collect(),
) )
.span(array.span) .span(array.span)
@ -489,13 +504,8 @@ impl<'ast> From<pest::ArrayInitializerExpression<'ast>> for absy::ExpressionNode
use crate::absy::NodeValue; use crate::absy::NodeValue;
let value = absy::ExpressionNode::from(*initializer.value); let value = absy::ExpressionNode::from(*initializer.value);
let count: absy::ExpressionNode<'ast> = absy::ExpressionNode::from(initializer.count); let count = absy::ExpressionNode::from(*initializer.count);
let count = match count.value { absy::Expression::ArrayInitializer(box value, box count).span(initializer.span)
absy::Expression::FieldConstant(v) => v.to_usize().unwrap(),
_ => unreachable!(),
};
absy::Expression::InlineArray(vec![absy::SpreadOrExpression::Expression(value); count])
.span(initializer.span)
} }
} }
@ -529,9 +539,25 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
pest::Access::Call(a) => match acc.value { pest::Access::Call(a) => match acc.value {
absy::Expression::Identifier(_) => absy::Expression::FunctionCall( absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
&id_str, &id_str,
a.expressions a.explicit_generics.map(|explicit_generics| {
explicit_generics
.values
.into_iter()
.map(|i| match i {
pest::ConstantGenericValue::Underscore(_) => None,
pest::ConstantGenericValue::Value(v) => {
Some(absy::ExpressionNode::from(v))
}
pest::ConstantGenericValue::Identifier(i) => {
Some(absy::Expression::Identifier(i.span.as_str()).span(i.span))
}
})
.collect()
}),
a.arguments
.expressions
.into_iter() .into_iter()
.map(|e| absy::ExpressionNode::from(e)) .map(absy::ExpressionNode::from)
.collect(), .collect(),
), ),
e => unimplemented!("only identifiers are callable, found \"{}\"", e), e => unimplemented!("only identifiers are callable, found \"{}\"", e),
@ -548,29 +574,63 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
} }
} }
impl<'ast> From<pest::ConstantExpression<'ast>> for absy::ExpressionNode<'ast> { impl<'ast> From<pest::DecimalLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::ConstantExpression<'ast>) -> absy::ExpressionNode<'ast> { fn from(expression: pest::DecimalLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue; use crate::absy::NodeValue;
match expression.suffix {
Some(suffix) => match suffix {
pest::DecimalSuffix::Field(_) => absy::Expression::FieldConstant(
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
),
pest::DecimalSuffix::U32(_) => absy::Expression::U32Constant(
u32::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
pest::DecimalSuffix::U16(_) => absy::Expression::U16Constant(
u16::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
pest::DecimalSuffix::U8(_) => absy::Expression::U8Constant(
u8::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
}
.span(expression.span),
None => absy::Expression::IntConstant(
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
)
.span(expression.span),
}
}
}
impl<'ast> From<pest::HexLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::HexLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
match expression.value {
pest::HexNumberExpression::U32(e) => {
absy::Expression::U32Constant(u32::from_str_radix(&e.span.as_str(), 16).unwrap())
}
pest::HexNumberExpression::U16(e) => {
absy::Expression::U16Constant(u16::from_str_radix(&e.span.as_str(), 16).unwrap())
}
pest::HexNumberExpression::U8(e) => {
absy::Expression::U8Constant(u8::from_str_radix(&e.span.as_str(), 16).unwrap())
}
}
.span(expression.span)
}
}
impl<'ast> From<pest::LiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::LiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
match expression { match expression {
pest::ConstantExpression::BooleanLiteral(c) => { pest::LiteralExpression::BooleanLiteral(c) => {
absy::Expression::BooleanConstant(c.value.parse().unwrap()).span(c.span) absy::Expression::BooleanConstant(c.value.parse().unwrap()).span(c.span)
} }
pest::ConstantExpression::DecimalNumber(n) => absy::Expression::FieldConstant( pest::LiteralExpression::DecimalLiteral(n) => absy::ExpressionNode::from(n),
BigUint::parse_bytes(&n.value.as_bytes(), 10).unwrap(), pest::LiteralExpression::HexLiteral(n) => absy::ExpressionNode::from(n),
)
.span(n.span),
pest::ConstantExpression::U8(n) => absy::Expression::U8Constant(
u8::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
pest::ConstantExpression::U16(n) => absy::Expression::U16Constant(
u16::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
pest::ConstantExpression::U32(n) => absy::Expression::U32Constant(
u32::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
} }
} }
} }
@ -611,8 +671,8 @@ impl<'ast> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast> {
} }
} }
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode { impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode { fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode<'ast> {
use crate::absy::types::UnresolvedType; use crate::absy::types::UnresolvedType;
use crate::absy::NodeValue; use crate::absy::NodeValue;
@ -642,21 +702,7 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
t.dimensions t.dimensions
.into_iter() .into_iter()
.map(|s| match s { .map(absy::ExpressionNode::from)
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => unimplemented!(
"Array size should be constant, found {}",
e.span().as_str()
),
})
.rev() .rev()
.fold(None, |acc, s| match acc { .fold(None, |acc, s| match acc {
None => Some(UnresolvedType::array(inner_type.clone(), s)), None => Some(UnresolvedType::array(inner_type.clone(), s)),
@ -690,10 +736,9 @@ mod tests {
arguments: vec![], arguments: vec![],
statements: vec![absy::Statement::Return( statements: vec![absy::Statement::Return(
absy::ExpressionList { absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(BigUint::from( expressions: vec![
42u32, absy::Expression::IntConstant(42usize.into()).into()
)) ],
.into()],
} }
.into(), .into(),
) )
@ -771,10 +816,9 @@ mod tests {
], ],
statements: vec![absy::Statement::Return( statements: vec![absy::Statement::Return(
absy::ExpressionList { absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(BigUint::from( expressions: vec![
42u32, absy::Expression::IntConstant(42usize.into()).into()
)) ],
.into()],
} }
.into(), .into(),
) )
@ -800,7 +844,7 @@ mod tests {
use super::*; use super::*;
/// Helper method to generate the ast for `def main(private {ty} a): return` which we use to check ty /// Helper method to generate the ast for `def main(private {ty} a): return` which we use to check ty
fn wrap(ty: UnresolvedType) -> absy::Module<'static> { fn wrap(ty: UnresolvedType<'static>) -> absy::Module<'static> {
absy::Module { absy::Module {
symbols: vec![absy::SymbolDeclaration { symbols: vec![absy::SymbolDeclaration {
id: "main", id: "main",
@ -834,21 +878,31 @@ mod tests {
("bool", UnresolvedType::Boolean), ("bool", UnresolvedType::Boolean),
( (
"field[2]", "field[2]",
UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 2), absy::UnresolvedType::Array(
), box absy::UnresolvedType::FieldElement.mock(),
( absy::Expression::IntConstant(2usize.into()).mock(),
"field[2][3]",
UnresolvedType::Array(
box UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 3)
.mock(),
2,
), ),
), ),
( (
"bool[2][3]", "field[2][3]",
UnresolvedType::Array( absy::UnresolvedType::Array(
box UnresolvedType::Array(box UnresolvedType::Boolean.mock(), 3).mock(), box absy::UnresolvedType::Array(
2, box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::IntConstant(3usize.into()).mock(),
)
.mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
(
"bool[2][3u32]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(),
absy::Expression::U32Constant(3u32).mock(),
)
.mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
), ),
), ),
]; ];
@ -893,13 +947,13 @@ mod tests {
// we basically accept `()?[]*` : an optional call at first, then only array accesses // we basically accept `()?[]*` : an optional call at first, then only array accesses
let vectors = vec![ let vectors = vec![
("a", absy::Expression::Identifier("a").into()), ("a", absy::Expression::Identifier("a")),
( (
"a[3]", "a[3]",
absy::Expression::Select( absy::Expression::Select(
box absy::Expression::Identifier("a").into(), box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(3u32)).into(), absy::Expression::IntConstant(3usize.into()).into(),
) )
.into(), .into(),
), ),
@ -910,13 +964,13 @@ mod tests {
box absy::Expression::Select( box absy::Expression::Select(
box absy::Expression::Identifier("a").into(), box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(3u32)).into(), absy::Expression::IntConstant(3usize.into()).into(),
) )
.into(), .into(),
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
), ),
@ -926,11 +980,12 @@ mod tests {
absy::Expression::Select( absy::Expression::Select(
box absy::Expression::FunctionCall( box absy::Expression::FunctionCall(
"a", "a",
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()], None,
vec![absy::Expression::IntConstant(3usize.into()).into()],
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
), ),
@ -941,17 +996,18 @@ mod tests {
box absy::Expression::Select( box absy::Expression::Select(
box absy::Expression::FunctionCall( box absy::Expression::FunctionCall(
"a", "a",
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()], None,
vec![absy::Expression::IntConstant(3usize.into()).into()],
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(), absy::Expression::IntConstant(4usize.into()).into(),
) )
.into(), .into(),
) )
.into(), .into(),
box absy::RangeOrExpression::Expression( box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(5u32)).into(), absy::Expression::IntConstant(5usize.into()).into(),
) )
.into(), .into(),
), ),
@ -993,7 +1049,7 @@ mod tests {
// For different definitions, we generate declarations // For different definitions, we generate declarations
// Case 1: `id = expr` where `expr` is not a function call // Case 1: `id = expr` where `expr` is not a function call
// This is a simple assignment, doesn't implicitely declare a variable // This is a simple assignment, doesn't implicitely declare a variable
// A `Definition` is generatedm and no `Declaration`s // A `Definition` is generated and no `Declaration`s
let definition = pest::DefinitionStatement { let definition = pest::DefinitionStatement {
lhs: vec![pest::OptionallyTypedAssignee { lhs: vec![pest::OptionallyTypedAssignee {
@ -1008,9 +1064,12 @@ mod tests {
}, },
span: span.clone(), span: span.clone(),
}], }],
expression: pest::Expression::Constant(pest::ConstantExpression::DecimalNumber( expression: pest::Expression::Literal(pest::LiteralExpression::DecimalLiteral(
pest::DecimalNumberExpression { pest::DecimalLiteralExpression {
value: String::from("42"), value: pest::DecimalNumber {
span: Span::new(&"1", 0, 1).unwrap(),
},
suffix: None,
span: span.clone(), span: span.clone(),
}, },
)), )),
@ -1049,7 +1108,11 @@ mod tests {
span: span.clone(), span: span.clone(),
}, },
accesses: vec![pest::Access::Call(pest::CallAccess { accesses: vec![pest::Access::Call(pest::CallAccess {
expressions: vec![], explicit_generics: None,
arguments: pest::Arguments {
expressions: vec![],
span: span.clone(),
},
span: span.clone(), span: span.clone(),
})], })],
span: span.clone(), span: span.clone(),
@ -1106,7 +1169,11 @@ mod tests {
span: span.clone(), span: span.clone(),
}, },
accesses: vec![pest::Access::Call(pest::CallAccess { accesses: vec![pest::Access::Call(pest::CallAccess {
expressions: vec![], explicit_generics: None,
arguments: pest::Arguments {
expressions: vec![],
span: span.clone(),
},
span: span.clone(), span: span.clone(),
})], })],
span: span.clone(), span: span.clone(),

View file

@ -105,7 +105,7 @@ impl<'ast> Module<'ast> {
} }
} }
pub type UnresolvedTypeNode = Node<UnresolvedType>; pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
/// A struct type definition /// A struct type definition
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -133,7 +133,7 @@ pub type StructDefinitionNode<'ast> = Node<StructDefinition<'ast>>;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct StructDefinitionField<'ast> { pub struct StructDefinitionField<'ast> {
pub id: Identifier<'ast>, pub id: Identifier<'ast>,
pub ty: UnresolvedTypeNode, pub ty: UnresolvedTypeNode<'ast>,
} }
impl<'ast> fmt::Display for StructDefinitionField<'ast> { impl<'ast> fmt::Display for StructDefinitionField<'ast> {
@ -216,6 +216,8 @@ impl<'ast> fmt::Debug for Module<'ast> {
} }
} }
pub type ConstantGenericNode<'ast> = Node<Identifier<'ast>>;
/// A function defined locally /// A function defined locally
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct Function<'ast> { pub struct Function<'ast> {
@ -224,13 +226,26 @@ pub struct Function<'ast> {
/// Vector of statements that are executed when running the function /// Vector of statements that are executed when running the function
pub statements: Vec<StatementNode<'ast>>, pub statements: Vec<StatementNode<'ast>>,
/// function signature /// function signature
pub signature: UnresolvedSignature, pub signature: UnresolvedSignature<'ast>,
} }
pub type FunctionNode<'ast> = Node<Function<'ast>>; pub type FunctionNode<'ast> = Node<Function<'ast>>;
impl<'ast> fmt::Display for Function<'ast> { impl<'ast> fmt::Display for Function<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.signature.generics.is_empty() {
write!(
f,
"<{}>",
self.signature
.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!( write!(
f, f,
"({}):\n{}", "({}):\n{}",
@ -294,6 +309,7 @@ impl<'ast> fmt::Display for Assignee<'ast> {
} }
/// A statement in a `Function` /// A statement in a `Function`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub enum Statement<'ast> { pub enum Statement<'ast> {
Return(ExpressionListNode<'ast>), Return(ExpressionListNode<'ast>),
@ -319,9 +335,9 @@ impl<'ast> fmt::Display for Statement<'ast> {
Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "assert({})", e), Statement::Assertion(ref e) => write!(f, "assert({})", e),
Statement::For(ref var, ref start, ref stop, ref list) => { Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?; writeln!(f, "for {} in {}..{} do", var, start, stop)?;
for l in list { for l in list {
write!(f, "\t\t{}\n", l)?; writeln!(f, "\t\t{}", l)?;
} }
write!(f, "\tendfor") write!(f, "\tendfor")
} }
@ -348,9 +364,9 @@ impl<'ast> fmt::Debug for Statement<'ast> {
} }
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e), Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => { Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?; writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?;
for l in list { for l in list {
write!(f, "\t\t{:?}\n", l)?; writeln!(f, "\t\t{:?}", l)?;
} }
write!(f, "\tendfor") write!(f, "\tendfor")
} }
@ -454,11 +470,11 @@ impl<'ast> fmt::Display for Range<'ast> {
self.from self.from
.as_ref() .as_ref()
.map(|e| e.to_string()) .map(|e| e.to_string())
.unwrap_or("".to_string()), .unwrap_or_else(|| "".to_string()),
self.to self.to
.as_ref() .as_ref()
.map(|e| e.to_string()) .map(|e| e.to_string())
.unwrap_or("".to_string()) .unwrap_or_else(|| "".to_string())
) )
} }
} }
@ -472,6 +488,7 @@ impl<'ast> fmt::Debug for Range<'ast> {
/// An expression /// An expression
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub enum Expression<'ast> { pub enum Expression<'ast> {
IntConstant(BigUint),
FieldConstant(BigUint), FieldConstant(BigUint),
BooleanConstant(bool), BooleanConstant(bool),
U8Constant(u8), U8Constant(u8),
@ -491,7 +508,11 @@ pub enum Expression<'ast> {
Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>,
Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>,
), ),
FunctionCall(FunctionIdentifier<'ast>, Vec<ExpressionNode<'ast>>), FunctionCall(
FunctionIdentifier<'ast>,
Option<Vec<Option<ExpressionNode<'ast>>>>,
Vec<ExpressionNode<'ast>>,
),
Lt(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>), Lt(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Le(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>), Le(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Eq(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>), Eq(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
@ -500,6 +521,7 @@ pub enum Expression<'ast> {
And(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>), And(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Not(Box<ExpressionNode<'ast>>), Not(Box<ExpressionNode<'ast>>),
InlineArray(Vec<SpreadOrExpression<'ast>>), InlineArray(Vec<SpreadOrExpression<'ast>>),
ArrayInitializer(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast>)>), InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast>)>),
Select(Box<ExpressionNode<'ast>>, Box<RangeOrExpression<'ast>>), Select(Box<ExpressionNode<'ast>>, Box<RangeOrExpression<'ast>>),
Member(Box<ExpressionNode<'ast>>, Box<Identifier<'ast>>), Member(Box<ExpressionNode<'ast>>, Box<Identifier<'ast>>),
@ -520,6 +542,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
Expression::U8Constant(ref i) => write!(f, "{}", i), Expression::U8Constant(ref i) => write!(f, "{}", i),
Expression::U16Constant(ref i) => write!(f, "{}", i), Expression::U16Constant(ref i) => write!(f, "{}", i),
Expression::U32Constant(ref i) => write!(f, "{}", i), Expression::U32Constant(ref i) => write!(f, "{}", i),
Expression::IntConstant(ref i) => write!(f, "{}", i),
Expression::Identifier(ref var) => write!(f, "{}", var), Expression::Identifier(ref var) => write!(f, "{}", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), Expression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
@ -535,8 +558,21 @@ impl<'ast> fmt::Display for Expression<'ast> {
"if {} then {} else {} fi", "if {} then {} else {} fi",
condition, consequent, alternative condition, consequent, alternative
), ),
Expression::FunctionCall(ref i, ref p) => { Expression::FunctionCall(ref i, ref g, ref p) => {
write!(f, "{}(", i,)?; if let Some(g) = g {
write!(
f,
"::<{}>",
g.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| "_".into()))
.collect::<Vec<_>>()
.join(", "),
)?;
}
write!(f, "{}(", i)?;
for (i, param) in p.iter().enumerate() { for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?; write!(f, "{}", param)?;
if i < p.len() - 1 { if i < p.len() - 1 {
@ -562,6 +598,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
} }
write!(f, "]") write!(f, "]")
} }
Expression::ArrayInitializer(ref e, ref count) => write!(f, "[{}; {}]", e, count),
Expression::InlineStruct(ref id, ref members) => { Expression::InlineStruct(ref id, ref members) => {
write!(f, "{} {{", id)?; write!(f, "{} {{", id)?;
for (i, (member_id, e)) in members.iter().enumerate() { for (i, (member_id, e)) in members.iter().enumerate() {
@ -587,10 +624,11 @@ impl<'ast> fmt::Display for Expression<'ast> {
impl<'ast> fmt::Debug for Expression<'ast> { impl<'ast> fmt::Debug for Expression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
Expression::U8Constant(ref i) => write!(f, "{:x}", i), Expression::U8Constant(ref i) => write!(f, "U8({:x})", i),
Expression::U16Constant(ref i) => write!(f, "{:x}", i), Expression::U16Constant(ref i) => write!(f, "U16({:x})", i),
Expression::U32Constant(ref i) => write!(f, "{:x}", i), Expression::U32Constant(ref i) => write!(f, "U32({:x})", i),
Expression::FieldConstant(ref i) => write!(f, "Num({:?})", i), Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i),
Expression::IntConstant(ref i) => write!(f, "Int({:?})", i),
Expression::Identifier(ref var) => write!(f, "Ide({})", var), Expression::Identifier(ref var) => write!(f, "Ide({})", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs), Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs), Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
@ -606,8 +644,8 @@ impl<'ast> fmt::Debug for Expression<'ast> {
"IfElse({:?}, {:?}, {:?})", "IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative condition, consequent, alternative
), ),
Expression::FunctionCall(ref i, ref p) => { Expression::FunctionCall(ref g, ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?; write!(f, "FunctionCall({:?}, {:?}, (", g, i)?;
f.debug_list().entries(p.iter()).finish()?; f.debug_list().entries(p.iter()).finish()?;
write!(f, ")") write!(f, ")")
} }
@ -623,6 +661,9 @@ impl<'ast> fmt::Debug for Expression<'ast> {
f.debug_list().entries(exprs.iter()).finish()?; f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]") write!(f, "]")
} }
Expression::ArrayInitializer(ref e, ref count) => {
write!(f, "ArrayInitializer({:?}, {:?})", e, count)
}
Expression::InlineStruct(ref id, ref members) => { Expression::InlineStruct(ref id, ref members) => {
write!(f, "InlineStruct({:?}, [", id)?; write!(f, "InlineStruct({:?}, [", id)?;
f.debug_list().entries(members.iter()).finish()?; f.debug_list().entries(members.iter()).finish()?;
@ -645,21 +686,13 @@ impl<'ast> fmt::Debug for Expression<'ast> {
} }
/// A list of expressions, used in return statements /// A list of expressions, used in return statements
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq, Default)]
pub struct ExpressionList<'ast> { pub struct ExpressionList<'ast> {
pub expressions: Vec<ExpressionNode<'ast>>, pub expressions: Vec<ExpressionNode<'ast>>,
} }
pub type ExpressionListNode<'ast> = Node<ExpressionList<'ast>>; pub type ExpressionListNode<'ast> = Node<ExpressionList<'ast>>;
impl<'ast> ExpressionList<'ast> {
pub fn new() -> ExpressionList<'ast> {
ExpressionList {
expressions: vec![],
}
}
}
impl<'ast> fmt::Display for ExpressionList<'ast> { impl<'ast> fmt::Display for ExpressionList<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for (i, param) in self.expressions.iter().enumerate() { for (i, param) in self.expressions.iter().enumerate() {

View file

@ -81,7 +81,7 @@ impl<'ast> NodeValue for ExpressionList<'ast> {}
impl<'ast> NodeValue for Assignee<'ast> {} impl<'ast> NodeValue for Assignee<'ast> {}
impl<'ast> NodeValue for Statement<'ast> {} impl<'ast> NodeValue for Statement<'ast> {}
impl<'ast> NodeValue for SymbolDeclaration<'ast> {} impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
impl NodeValue for UnresolvedType {} impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {}
impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for Function<'ast> {} impl<'ast> NodeValue for Function<'ast> {}
@ -92,6 +92,7 @@ impl<'ast> NodeValue for Parameter<'ast> {}
impl<'ast> NodeValue for Import<'ast> {} impl<'ast> NodeValue for Import<'ast> {}
impl<'ast> NodeValue for Spread<'ast> {} impl<'ast> NodeValue for Spread<'ast> {}
impl<'ast> NodeValue for Range<'ast> {} impl<'ast> NodeValue for Range<'ast> {}
impl<'ast> NodeValue for Identifier<'ast> {}
impl<T: PartialEq> PartialEq for Node<T> { impl<T: PartialEq> PartialEq for Node<T> {
fn eq(&self, other: &Node<T>) -> bool { fn eq(&self, other: &Node<T>) -> bool {

View file

@ -1,3 +1,4 @@
use crate::absy::ExpressionNode;
use crate::absy::UnresolvedTypeNode; use crate::absy::UnresolvedTypeNode;
use std::fmt; use std::fmt;
@ -8,15 +9,15 @@ pub type MemberId = String;
pub type UserTypeId = String; pub type UserTypeId = String;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub enum UnresolvedType { pub enum UnresolvedType<'ast> {
FieldElement, FieldElement,
Boolean, Boolean,
Uint(usize), Uint(usize),
Array(Box<UnresolvedTypeNode>, usize), Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>),
User(UserTypeId), User(UserTypeId),
} }
impl fmt::Display for UnresolvedType { impl<'ast> fmt::Display for UnresolvedType<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
UnresolvedType::FieldElement => write!(f, "field"), UnresolvedType::FieldElement => write!(f, "field"),
@ -28,8 +29,8 @@ impl fmt::Display for UnresolvedType {
} }
} }
impl UnresolvedType { impl<'ast> UnresolvedType<'ast> {
pub fn array(ty: UnresolvedTypeNode, size: usize) -> Self { pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self {
UnresolvedType::Array(box ty, size) UnresolvedType::Array(box ty, size)
} }
} }
@ -39,17 +40,19 @@ pub type FunctionIdentifier<'ast> = &'ast str;
pub use self::signature::UnresolvedSignature; pub use self::signature::UnresolvedSignature;
mod signature { mod signature {
use crate::absy::ConstantGenericNode;
use std::fmt; use std::fmt;
use crate::absy::UnresolvedTypeNode; use crate::absy::UnresolvedTypeNode;
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq, Default)]
pub struct UnresolvedSignature { pub struct UnresolvedSignature<'ast> {
pub inputs: Vec<UnresolvedTypeNode>, pub generics: Vec<ConstantGenericNode<'ast>>,
pub outputs: Vec<UnresolvedTypeNode>, pub inputs: Vec<UnresolvedTypeNode<'ast>>,
pub outputs: Vec<UnresolvedTypeNode<'ast>>,
} }
impl fmt::Debug for UnresolvedSignature { impl<'ast> fmt::Debug for UnresolvedSignature<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
f, f,
@ -59,7 +62,7 @@ mod signature {
} }
} }
impl fmt::Display for UnresolvedSignature { impl<'ast> fmt::Display for UnresolvedSignature<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?; write!(f, "(")?;
for (i, t) in self.inputs.iter().enumerate() { for (i, t) in self.inputs.iter().enumerate() {
@ -79,20 +82,22 @@ mod signature {
} }
} }
impl UnresolvedSignature { impl<'ast> UnresolvedSignature<'ast> {
pub fn new() -> UnresolvedSignature { pub fn new() -> UnresolvedSignature<'ast> {
UnresolvedSignature { UnresolvedSignature::default()
inputs: vec![],
outputs: vec![],
}
} }
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode>) -> Self { pub fn generics(mut self, generics: Vec<ConstantGenericNode<'ast>>) -> Self {
self.generics = generics;
self
}
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
self.inputs = inputs; self.inputs = inputs;
self self
} }
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode>) -> Self { pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
self.outputs = outputs; self.outputs = outputs;
self self
} }

View file

@ -7,21 +7,21 @@ use crate::absy::Identifier;
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub struct Variable<'ast> { pub struct Variable<'ast> {
pub id: Identifier<'ast>, pub id: Identifier<'ast>,
pub _type: UnresolvedTypeNode, pub _type: UnresolvedTypeNode<'ast>,
} }
pub type VariableNode<'ast> = Node<Variable<'ast>>; pub type VariableNode<'ast> = Node<Variable<'ast>>;
impl<'ast> Variable<'ast> { impl<'ast> Variable<'ast> {
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode) -> Variable<'ast> { pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode<'ast>) -> Variable<'ast> {
Variable { Variable {
id: id.into(), id: id.into(),
_type: t, _type: t,
} }
} }
pub fn get_type(&self) -> UnresolvedType { pub fn get_type(&self) -> &UnresolvedType<'ast> {
self._type.value.clone() &self._type.value
} }
} }

View file

@ -9,6 +9,7 @@ use crate::imports::{self, Importer};
use crate::ir; use crate::ir;
use crate::macros; use crate::macros;
use crate::semantics::{self, Checker}; use crate::semantics::{self, Checker};
use crate::static_analysis;
use crate::static_analysis::Analyse; use crate::static_analysis::Analyse;
use crate::typed_absy::abi::Abi; use crate::typed_absy::abi::Abi;
use crate::zir::ZirProgram; use crate::zir::ZirProgram;
@ -55,6 +56,7 @@ pub enum CompileErrorInner {
MacroError(macros::Error), MacroError(macros::Error),
SemanticError(semantics::ErrorInner), SemanticError(semantics::ErrorInner),
ReadError(io::Error), ReadError(io::Error),
AnalysisError(static_analysis::Error),
} }
impl CompileErrorInner { impl CompileErrorInner {
@ -129,6 +131,12 @@ impl From<semantics::Error> for CompileError {
} }
} }
impl From<static_analysis::Error> for CompileErrorInner {
fn from(error: static_analysis::Error) -> Self {
CompileErrorInner::AnalysisError(error)
}
}
impl fmt::Display for CompileErrorInner { impl fmt::Display for CompileErrorInner {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
@ -137,6 +145,7 @@ impl fmt::Display for CompileErrorInner {
CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e), CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e),
CompileErrorInner::ReadError(ref e) => write!(f, "{}", e), CompileErrorInner::ReadError(ref e) => write!(f, "{}", e),
CompileErrorInner::ImportError(ref e) => write!(f, "{}", e), CompileErrorInner::ImportError(ref e) => write!(f, "{}", e),
CompileErrorInner::AnalysisError(ref e) => write!(f, "{}", e),
} }
} }
} }
@ -179,7 +188,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
}) })
} }
pub fn check<'ast, T: Field, E: Into<imports::Error>>( pub fn check<T: Field, E: Into<imports::Error>>(
source: String, source: String,
location: FilePath, location: FilePath,
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
@ -196,19 +205,18 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
arena: &'ast Arena<String>, arena: &'ast Arena<String>,
) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> { ) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> {
let source = arena.alloc(source); let source = arena.alloc(source);
let compiled = compile_program::<T, E>(source, location.clone(), resolver, &arena)?; let compiled = compile_program::<T, E>(source, location, resolver, &arena)?;
// check semantics // check semantics
let typed_ast = Checker::check(compiled).map_err(|errors| { let typed_ast = Checker::check(compiled)
CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect()) .map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?;
})?;
let abi = typed_ast.abi(); let main_module = typed_ast.main.clone();
// analyse (unroll and constant propagation) // analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse(); typed_ast
.analyse()
Ok((typed_ast, abi)) .map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)]))
} }
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>( pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
@ -244,7 +252,7 @@ pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>(
let module_without_imports: Module = Module::from(ast); let module_without_imports: Module = Module::from(ast);
Importer::new().apply_imports::<T, E>( Importer::apply_imports::<T, E>(
module_without_imports, module_without_imports,
location.clone(), location.clone(),
resolver, resolver,
@ -380,17 +388,17 @@ struct Bar { field a }
inputs: vec![AbiInput { inputs: vec![AbiInput {
name: "f".into(), name: "f".into(),
public: true, public: true,
ty: Type::Struct(StructType::new( ty: ConcreteType::Struct(ConcreteStructType::new(
"foo".into(), "foo".into(),
"Foo".into(), "Foo".into(),
vec![StructMember { vec![ConcreteStructMember {
id: "b".into(), id: "b".into(),
ty: box Type::Struct(StructType::new( ty: box ConcreteType::Struct(ConcreteStructType::new(
"bar".into(), "bar".into(),
"Bar".into(), "Bar".into(),
vec![StructMember { vec![ConcreteStructMember {
id: "a".into(), id: "a".into(),
ty: box Type::FieldElement ty: box ConcreteType::FieldElement
}] }]
)) ))
}] }]

View file

@ -3,7 +3,9 @@ use crate::flat_absy::{
FlatVariable, FlatVariable,
}; };
use crate::solvers::Solver; use crate::solvers::Solver;
use crate::typed_absy::types::{FunctionKey, Signature, Type}; use crate::typed_absy::types::{
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType,
};
use std::collections::HashMap; use std::collections::HashMap;
use zokrates_field::{Bn128Field, Field}; use zokrates_field::{Bn128Field, Field};
@ -18,11 +20,12 @@ cfg_if::cfg_if! {
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined. /// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub enum FlatEmbed { pub enum FlatEmbed {
U32ToField,
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
Sha256Round, Sha256Round,
Unpack(usize), Unpack,
U8ToBits, U8ToBits,
U16ToBits, U16ToBits,
U32ToBits, U32ToBits,
@ -32,48 +35,88 @@ pub enum FlatEmbed {
} }
impl FlatEmbed { impl FlatEmbed {
pub fn signature(&self) -> Signature { pub fn signature(&self) -> DeclarationSignature<'static> {
match self { match self {
FlatEmbed::U32ToField => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
FlatEmbed::Unpack => DeclarationSignature::new()
.generics(vec![Some(Constant::Generic("N"))])
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
"N",
))]),
FlatEmbed::U8ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(8)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
8usize,
))]),
FlatEmbed::U16ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(16)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
16usize,
))]),
FlatEmbed::U32ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
32usize,
))]),
FlatEmbed::U8FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(8)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
8usize,
))]),
FlatEmbed::U16FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(16)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
16usize,
))]),
FlatEmbed::U32FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(32)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
32usize,
))]),
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => Signature::new() FlatEmbed::Sha256Round => DeclarationSignature::new()
.inputs(vec![ .inputs(vec![
Type::array(Type::Boolean, 512), DeclarationType::array((DeclarationType::Boolean, 512usize)),
Type::array(Type::Boolean, 256), DeclarationType::array((DeclarationType::Boolean, 256usize)),
]) ])
.outputs(vec![Type::array(Type::Boolean, 256)]), .outputs(vec![DeclarationType::array((
FlatEmbed::Unpack(bitwidth) => Signature::new() DeclarationType::Boolean,
.inputs(vec![Type::FieldElement]) 256usize,
.outputs(vec![Type::array(Type::Boolean, *bitwidth)]), ))]),
FlatEmbed::U8ToBits => Signature::new()
.inputs(vec![Type::uint(8)])
.outputs(vec![Type::array(Type::Boolean, 8)]),
FlatEmbed::U16ToBits => Signature::new()
.inputs(vec![Type::uint(16)])
.outputs(vec![Type::array(Type::Boolean, 16)]),
FlatEmbed::U32ToBits => Signature::new()
.inputs(vec![Type::uint(32)])
.outputs(vec![Type::array(Type::Boolean, 32)]),
FlatEmbed::U8FromBits => Signature::new()
.outputs(vec![Type::uint(8)])
.inputs(vec![Type::array(Type::Boolean, 8)]),
FlatEmbed::U16FromBits => Signature::new()
.outputs(vec![Type::uint(16)])
.inputs(vec![Type::array(Type::Boolean, 16)]),
FlatEmbed::U32FromBits => Signature::new()
.outputs(vec![Type::uint(32)])
.inputs(vec![Type::array(Type::Boolean, 32)]),
} }
} }
pub fn key<T: Field>(&self) -> FunctionKey<'static> { pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
FunctionKey::with_id(self.id()).signature(self.signature()) let gen = self
.signature()
.generics
.into_iter()
.map(|c| match c.unwrap() {
Constant::Generic(g) => g,
_ => unreachable!(),
});
assert_eq!(gen.len(), assignment.0.len());
gen.map(|g| *assignment.0.get(&g).clone().unwrap() as u32)
.collect()
} }
pub fn id(&self) -> &'static str { pub fn id(&self) -> &'static str {
match self { match self {
FlatEmbed::U32ToField => "_U32_TO_FIELD",
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => "_SHA256_ROUND", FlatEmbed::Sha256Round => "_SHA256_ROUND",
FlatEmbed::Unpack(_) => "_UNPACK", FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS", FlatEmbed::U8ToBits => "_U8_TO_BITS",
FlatEmbed::U16ToBits => "_U16_TO_BITS", FlatEmbed::U16ToBits => "_U16_TO_BITS",
FlatEmbed::U32ToBits => "_U32_TO_BITS", FlatEmbed::U32ToBits => "_U32_TO_BITS",
@ -84,11 +127,11 @@ impl FlatEmbed {
} }
/// Actually get the `FlatFunction` that this `FlatEmbed` represents /// Actually get the `FlatFunction` that this `FlatEmbed` represents
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> { pub fn synthetize<T: Field>(&self, generics: &[u32]) -> FlatFunction<T> {
match self { match self {
#[cfg(feature = "bellman")] #[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => sha256_round(), FlatEmbed::Sha256Round => sha256_round(),
FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth), FlatEmbed::Unpack => unpack_to_bitwidth(generics[0] as usize),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -101,7 +144,7 @@ fn flat_expression_from_vec<T: Field, E: Engine>(v: &[(usize, E::Fr)]) -> FlatEx
match v.len() { match v.len() {
0 => FlatExpression::Number(T::zero()), 0 => FlatExpression::Number(T::zero()),
1 => { 1 => {
let (key, val) = v[0].clone(); let (key, val) = v[0];
let mut res: Vec<u8> = vec![]; let mut res: Vec<u8> = vec![];
val.into_repr().write_le(&mut res).unwrap(); val.into_repr().write_le(&mut res).unwrap();
FlatExpression::Mult( FlatExpression::Mult(
@ -152,7 +195,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
let output_indices = output_indices.into_iter(); let output_indices = output_indices.into_iter();
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
// indices of the sha256round constraint system variables // indices of the sha256round constraint system variables
let cs_indices = (0..variable_count).into_iter(); let cs_indices = 0..variable_count;
// indices of the arguments to the function // indices of the arguments to the function
// apply an offset of `variable_count` to get the indice of our dummy `input` argument // apply an offset of `variable_count` to get the indice of our dummy `input` argument
let input_argument_indices = input_indices let input_argument_indices = input_indices
@ -180,7 +223,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
); );
let input_binding_statements = let input_binding_statements =
// bind input and current_hash to inputs // bind input and current_hash to inputs
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| { input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
FlatStatement::Condition( FlatStatement::Condition(
FlatVariable::new(cs_index).into(), FlatVariable::new(cs_index).into(),
FlatVariable::new(argument_index).into(), FlatVariable::new(argument_index).into(),
@ -194,7 +237,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
.collect(); .collect();
// insert a directive to set the witness based on the bellman gadget and inputs // insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(FlatDirective { let directive_statement = FlatStatement::Directive(FlatDirective {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(), outputs: cs_indices.map(FlatVariable::new).collect(),
inputs: input_argument_indices inputs: input_argument_indices
.chain(current_hash_argument_indices) .chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into()) .map(|i| FlatVariable::new(i).into())
@ -224,7 +267,7 @@ fn use_variable(
) -> FlatVariable { ) -> FlatVariable {
let var = FlatVariable::new(*index); let var = FlatVariable::new(*index);
layout.insert(name, var); layout.insert(name, var);
*index = *index + 1; *index += 1;
var var
} }
@ -255,7 +298,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let directive_inputs = vec![FlatExpression::Identifier(use_variable( let directive_inputs = vec![FlatExpression::Identifier(use_variable(
&mut layout, &mut layout,
format!("i0"), "i0".into(),
&mut counter, &mut counter,
))]; ))];
@ -268,7 +311,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let outputs = directive_outputs let outputs = directive_outputs
.iter() .iter()
.enumerate() .enumerate()
.map(|(_, o)| FlatExpression::Identifier(o.clone())) .map(|(_, o)| FlatExpression::Identifier(*o))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// o253, o252, ... o{253 - (bit_width - 1)} are bits // o253, o252, ... o{253 - (bit_width - 1)} are bits
@ -308,7 +351,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
FlatStatement::Directive(FlatDirective { FlatStatement::Directive(FlatDirective {
inputs: directive_inputs, inputs: directive_inputs,
outputs: directive_outputs, outputs: directive_outputs,
solver: solver, solver,
}), }),
); );
@ -436,7 +479,7 @@ mod tests {
private: vec![true; 768], private: vec![true; 768],
}; };
let input = (0..512) let input: Vec<_> = (0..512)
.map(|_| 0) .map(|_| 0)
.chain((0..256).map(|_| 1)) .chain((0..256).map(|_| 1))
.map(|i| Bn128Field::from(i)) .map(|i| Bn128Field::from(i))

View file

@ -41,7 +41,7 @@ impl FlatParameter {
substitution: &HashMap<FlatVariable, FlatVariable>, substitution: &HashMap<FlatVariable, FlatVariable>,
) -> FlatParameter { ) -> FlatParameter {
FlatParameter { FlatParameter {
id: substitution.get(&self.id).unwrap().clone(), id: *substitution.get(&self.id).unwrap(),
private: self.private, private: self.private,
} }
} }

View file

@ -45,7 +45,7 @@ impl FlatVariable {
Ok(FlatVariable::public(v)) Ok(FlatVariable::public(v))
} }
None => { None => {
let mut private = s.split("_"); let mut private = s.split('_');
match private.nth(1) { match private.nth(1) {
Some(v) => { Some(v) => {
let v = v.parse().map_err(|_| s)?; let v = v.parse().map_err(|_| s)?;

View file

@ -155,23 +155,13 @@ impl<T: Field> FlatStatement<T> {
} }
} }
#[derive(Clone, Hash, Debug)] #[derive(Clone, Hash, Debug, PartialEq, Eq)]
pub struct FlatDirective<T: Field> { pub struct FlatDirective<T: Field> {
pub inputs: Vec<FlatExpression<T>>, pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>, pub outputs: Vec<FlatVariable>,
pub solver: Solver, pub solver: Solver,
} }
impl<T: Field> PartialEq for FlatDirective<T> {
fn eq(&self, other: &Self) -> bool {
self.inputs.eq(&other.inputs)
&& self.outputs.eq(&other.outputs)
&& self.solver.eq(&other.solver)
}
}
impl<T: Field> Eq for FlatDirective<T> {}
impl<T: Field> FlatDirective<T> { impl<T: Field> FlatDirective<T> {
pub fn new<E: Into<FlatExpression<T>>>( pub fn new<E: Into<FlatExpression<T>>>(
outputs: Vec<FlatVariable>, outputs: Vec<FlatVariable>,
@ -249,12 +239,18 @@ impl<T: Field> FlatExpression<T> {
FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => { FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => {
x.is_linear() && y.is_linear() x.is_linear() && y.is_linear()
} }
FlatExpression::Mult(ref x, ref y) => match (x.clone(), y.clone()) { FlatExpression::Mult(ref x, ref y) => matches!(
(x.clone(), y.clone()),
(box FlatExpression::Number(_), box FlatExpression::Number(_)) (box FlatExpression::Number(_), box FlatExpression::Number(_))
| (box FlatExpression::Number(_), box FlatExpression::Identifier(_)) | (
| (box FlatExpression::Identifier(_), box FlatExpression::Number(_)) => true, box FlatExpression::Number(_),
_ => false, box FlatExpression::Identifier(_)
}, )
| (
box FlatExpression::Identifier(_),
box FlatExpression::Number(_)
)
),
} }
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -42,7 +42,7 @@ impl fmt::Display for Error {
let location = self let location = self
.pos .pos
.map(|p| format!("{}", p.0)) .map(|p| format!("{}", p.0))
.unwrap_or("?".to_string()); .unwrap_or_else(|| "?".to_string());
write!(f, "{}\n\t{}", location, self.message) write!(f, "{}\n\t{}", location, self.message)
} }
} }
@ -125,15 +125,10 @@ impl<'ast> fmt::Debug for Import<'ast> {
} }
} }
pub struct Importer {} pub struct Importer;
impl Importer { impl Importer {
pub fn new() -> Importer {
Importer {}
}
pub fn apply_imports<'ast, T: Field, E: Into<Error>>( pub fn apply_imports<'ast, T: Field, E: Into<Error>>(
&self,
destination: Module<'ast>, destination: Module<'ast>,
location: PathBuf, location: PathBuf,
resolver: Option<&dyn Resolver<E>>, resolver: Option<&dyn Resolver<E>>,
@ -179,7 +174,7 @@ impl Importer {
symbols.push( symbols.push(
SymbolDeclaration { SymbolDeclaration {
id: &alias, id: &alias,
symbol: Symbol::Flat(FlatEmbed::Unpack(T::get_required_bits())), symbol: Symbol::Flat(FlatEmbed::Unpack),
} }
.start_end(pos.0, pos.1), .start_end(pos.0, pos.1),
); );
@ -267,13 +262,15 @@ impl Importer {
let alias = import.alias.unwrap_or( let alias = import.alias.unwrap_or(
std::path::Path::new(import.source) std::path::Path::new(import.source)
.file_stem() .file_stem()
.ok_or(CompileErrors::from( .ok_or_else(|| {
CompileErrorInner::ImportError(Error::new(format!( CompileErrors::from(
"Could not determine alias for import {}", CompileErrorInner::ImportError(Error::new(format!(
import.source.display() "Could not determine alias for import {}",
))) import.source.display()
.in_file(&location), )))
))? .in_file(&location),
)
})?
.to_str() .to_str()
.unwrap(), .unwrap(),
); );
@ -335,7 +332,6 @@ impl Importer {
Ok(Module { Ok(Module {
imports: vec![], imports: vec![],
symbols, symbols,
..destination
}) })
} }
} }

View file

@ -2,49 +2,36 @@ use crate::flat_absy::FlatVariable;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::btree_map::{BTreeMap, Entry}; use std::collections::btree_map::{BTreeMap, Entry};
use std::fmt; use std::fmt;
use std::hash::Hash;
use std::ops::{Add, Div, Mul, Sub}; use std::ops::{Add, Div, Mul, Sub};
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Debug, Clone, Serialize, Deserialize, Hash)] #[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct QuadComb<T> { pub struct QuadComb<T> {
pub left: LinComb<T>, pub left: LinComb<T>,
pub right: LinComb<T>, pub right: LinComb<T>,
} }
impl<T: Field> PartialEq for QuadComb<T> {
fn eq(&self, other: &Self) -> bool {
self.left.eq(&other.left) && self.right.eq(&other.right)
}
}
impl<T: Field> Eq for QuadComb<T> {}
impl<T: Field> QuadComb<T> { impl<T: Field> QuadComb<T> {
pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self { pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self {
QuadComb { left, right } QuadComb { left, right }
} }
pub fn try_linear(&self) -> Option<LinComb<T>> { pub fn try_linear(self) -> Result<LinComb<T>, Self> {
// identify (k * ~ONE) * (lincomb) and return (k * lincomb) // identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb)
// if not, error out with the input
match self.left.try_summand() {
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
return Some(self.right.clone() * &coefficient);
}
_ => {}
}
match self.right.try_summand() {
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
return Some(self.left.clone() * &coefficient);
}
_ => {}
}
if self.left.is_zero() || self.right.is_zero() { if self.left.is_zero() || self.right.is_zero() {
return Some(LinComb::zero()); return Ok(LinComb::zero());
} }
None match self.left.try_constant() {
Ok(coefficient) => Ok(self.right * &coefficient),
Err(left) => match self.right.try_constant() {
Ok(coefficient) => Ok(left * &coefficient),
Err(right) => Err(QuadComb::from_linear_combinations(left, right)),
},
}
} }
} }
@ -66,17 +53,9 @@ impl<T: Field> fmt::Display for QuadComb<T> {
} }
} }
#[derive(Clone, Hash, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct LinComb<T>(pub Vec<(FlatVariable, T)>); pub struct LinComb<T>(pub Vec<(FlatVariable, T)>);
impl<T: Field> PartialEq for LinComb<T> {
fn eq(&self, other: &Self) -> bool {
self.clone().into_canonical() == other.clone().into_canonical()
}
}
impl<T: Field> Eq for LinComb<T> {}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>); pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>);
@ -113,36 +92,52 @@ impl<T> LinComb<T> {
} }
pub fn is_zero(&self) -> bool { pub fn is_zero(&self) -> bool {
self.0.len() == 0 self.0.is_empty()
} }
} }
impl<T: Field> LinComb<T> { impl<T: Field> LinComb<T> {
pub fn try_summand(&self) -> Option<(FlatVariable, T)> { pub fn try_constant(self) -> Result<T, Self> {
match self.0.len() { match self.0.len() {
// if the lincomb is empty, it is not reduceable to a summand // if the lincomb is empty, it is reduceable to 0
0 => None, 0 => Ok(T::zero()),
_ => { _ => {
// take the first variable in the lincomb // take the first variable in the lincomb
let first = &self.0[0].0; let first = &self.0[0].0;
self.0 if first != &FlatVariable::one() {
.iter() return Err(self);
.map(|element| { }
// all terms must contain the same variable
if self.0.iter().all(|element| element.0 == *first) {
Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1))
} else {
Err(self)
}
}
}
}
pub fn try_summand(self) -> Result<(FlatVariable, T), Self> {
match self.0.len() {
// if the lincomb is empty, it is not reduceable to a summand
0 => Err(self),
_ => {
// take the first variable in the lincomb
let first = &self.0[0].0;
if self.0.iter().all(|element|
// all terms must contain the same variable // all terms must contain the same variable
if element.0 == *first { element.0 == *first)
// if they do, return the coefficient {
Ok(&element.1) Ok((
} else { *first,
// otherwise, stop self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1),
Err(()) ))
} } else {
}) Err(self)
// collect to a Result to short circuit when we hit an error }
.collect::<Result<_, _>>()
// we didn't hit an error, do final processing. It's fine to clone here.
.map(|v: Vec<_>| (first.clone(), v.iter().fold(T::zero(), |acc, e| acc + *e)))
.ok()
} }
} }
} }
@ -207,9 +202,7 @@ impl<T: Field> fmt::Display for LinComb<T> {
false => write!( false => write!(
f, f,
"{}", "{}",
self.clone() self.0
.into_canonical()
.0
.iter() .iter()
.map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k)) .map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k))
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -251,10 +244,14 @@ impl<T: Field> Mul<&T> for LinComb<T> {
type Output = LinComb<T>; type Output = LinComb<T>;
fn mul(self, scalar: &T) -> LinComb<T> { fn mul(self, scalar: &T) -> LinComb<T> {
if scalar == &T::one() {
return self;
}
LinComb( LinComb(
self.0 self.0
.into_iter() .into_iter()
.map(|(var, coeff)| (var, coeff * scalar)) .map(|(var, coeff)| (var, coeff * scalar.clone()))
.collect(), .collect(),
) )
} }
@ -262,7 +259,8 @@ impl<T: Field> Mul<&T> for LinComb<T> {
impl<T: Field> Div<&T> for LinComb<T> { impl<T: Field> Div<&T> for LinComb<T> {
type Output = LinComb<T>; type Output = LinComb<T>;
// Clippy warns about multiplication in a method named div. It's okay, here, since we multiply with the inverse.
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, scalar: &T) -> LinComb<T> { fn div(self, scalar: &T) -> LinComb<T> {
self * &scalar.inverse_mul().unwrap() self * &scalar.inverse_mul().unwrap()
} }
@ -287,7 +285,7 @@ mod tests {
fn add() { fn add() {
let a: LinComb<Bn128Field> = FlatVariable::new(42).into(); let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
let b: LinComb<Bn128Field> = FlatVariable::new(42).into(); let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
let c = a + b.clone(); let c = a + b;
let expected_vec = vec![ let expected_vec = vec![
(FlatVariable::new(42), Bn128Field::from(1)), (FlatVariable::new(42), Bn128Field::from(1)),
@ -300,7 +298,7 @@ mod tests {
fn sub() { fn sub() {
let a: LinComb<Bn128Field> = FlatVariable::new(42).into(); let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
let b: LinComb<Bn128Field> = FlatVariable::new(42).into(); let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
let c = a - b.clone(); let c = a - b;
let expected_vec = vec![ let expected_vec = vec![
(FlatVariable::new(42), Bn128Field::from(1)), (FlatVariable::new(42), Bn128Field::from(1)),
@ -314,7 +312,7 @@ mod tests {
fn display() { fn display() {
let a: LinComb<Bn128Field> = let a: LinComb<Bn128Field> =
LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21)); LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21));
assert_eq!(&a.to_string(), "3 * _21 + 1 * _42"); assert_eq!(&a.to_string(), "1 * _42 + 3 * _21");
let zero: LinComb<Bn128Field> = LinComb::zero(); let zero: LinComb<Bn128Field> = LinComb::zero();
assert_eq!(&zero.to_string(), "0"); assert_eq!(&zero.to_string(), "0");
} }
@ -350,7 +348,7 @@ mod tests {
+ LinComb::summand(4, FlatVariable::new(33)), + LinComb::summand(4, FlatVariable::new(33)),
right: LinComb::summand(1, FlatVariable::new(21)), right: LinComb::summand(1, FlatVariable::new(21)),
}; };
assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)"); assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)");
let a: QuadComb<Bn128Field> = QuadComb { let a: QuadComb<Bn128Field> = QuadComb {
left: LinComb::zero(), left: LinComb::zero(),
right: LinComb::summand(1, FlatVariable::new(21)), right: LinComb::summand(1, FlatVariable::new(21)),
@ -371,7 +369,7 @@ mod tests {
]); ]);
assert_eq!( assert_eq!(
summand.try_summand(), summand.try_summand(),
Some((FlatVariable::new(42), Bn128Field::from(6))) Ok((FlatVariable::new(42), Bn128Field::from(6)))
); );
let not_summand = LinComb(vec![ let not_summand = LinComb(vec![
@ -379,10 +377,10 @@ mod tests {
(FlatVariable::new(42), Bn128Field::from(2)), (FlatVariable::new(42), Bn128Field::from(2)),
(FlatVariable::new(42), Bn128Field::from(3)), (FlatVariable::new(42), Bn128Field::from(3)),
]); ]);
assert_eq!(not_summand.try_summand(), None); assert!(not_summand.try_summand().is_err());
let empty: LinComb<Bn128Field> = LinComb(vec![]); let empty: LinComb<Bn128Field> = LinComb(vec![]);
assert_eq!(empty.try_summand(), None); assert!(empty.try_summand().is_err());
} }
} }
} }

View file

@ -125,7 +125,7 @@ impl<T: Field> From<FlatDirective<T>> for Directive<T> {
inputs: ds inputs: ds
.inputs .inputs
.into_iter() .into_iter()
.map(|i| QuadComb::from_flat_expression(i)) .map(QuadComb::from_flat_expression)
.collect(), .collect(),
solver: ds.solver, solver: ds.solver,
outputs: ds.outputs, outputs: ds.outputs,

View file

@ -1,10 +1,13 @@
use crate::flat_absy::flat_variable::FlatVariable; use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::Directive; use crate::ir::Directive;
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
use crate::solvers::{Executable, Solver}; use crate::solvers::Solver;
use pairing_ce::bn256::Bn256;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::fmt; use std::fmt;
#[cfg(feature = "bellman")]
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::Field; use zokrates_field::Field;
pub type ExecutionResult<T> = Result<Witness<T>, Error>; pub type ExecutionResult<T> = Result<Witness<T>, Error>;
@ -34,13 +37,13 @@ impl Interpreter {
} }
impl Interpreter { impl Interpreter {
pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &Vec<T>) -> ExecutionResult<T> { pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &[T]) -> ExecutionResult<T> {
let main = &program.main; let main = &program.main;
self.check_inputs(&program, &inputs)?; self.check_inputs(&program, &inputs)?;
let mut witness = BTreeMap::new(); let mut witness = BTreeMap::new();
witness.insert(FlatVariable::one(), T::one()); witness.insert(FlatVariable::one(), T::one());
for (arg, value) in main.arguments.iter().zip(inputs.iter()) { for (arg, value) in main.arguments.iter().zip(inputs.iter()) {
witness.insert(arg.clone(), value.clone().into()); witness.insert(*arg, value.clone());
} }
for statement in main.statements.iter() { for statement in main.statements.iter() {
@ -48,7 +51,7 @@ impl Interpreter {
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) { Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
true => { true => {
let val = quad.evaluate(&witness).unwrap(); let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val); witness.insert(lin.0.get(0).unwrap().0, val);
} }
false => { false => {
let lhs_value = quad.evaluate(&witness).unwrap(); let lhs_value = quad.evaluate(&witness).unwrap();
@ -76,10 +79,10 @@ impl Interpreter {
.iter() .iter()
.map(|i| i.evaluate(&witness).unwrap()) .map(|i| i.evaluate(&witness).unwrap())
.collect(); .collect();
match d.solver.execute(&inputs) { match self.execute_solver(&d.solver, &inputs) {
Ok(res) => { Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() { for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone()); witness.insert(*o, res[i].clone());
} }
continue; continue;
} }
@ -107,12 +110,12 @@ impl Interpreter {
value.to_biguint() value.to_biguint()
}; };
let mut num = input.clone(); let mut num = input;
let mut res = vec![]; let mut res = vec![];
let bits = T::get_required_bits(); let bits = T::get_required_bits();
for i in (0..bits).rev() { for i in (0..bits).rev() {
if T::from(2).to_biguint().pow(i as usize) <= num { if T::from(2).to_biguint().pow(i as usize) <= num {
num = num - T::from(2).to_biguint().pow(i as usize); num -= T::from(2).to_biguint().pow(i as usize);
res.push(T::one()); res.push(T::one());
} else { } else {
res.push(T::zero()); res.push(T::zero());
@ -120,11 +123,11 @@ impl Interpreter {
} }
assert_eq!(num, T::zero().to_biguint()); assert_eq!(num, T::zero().to_biguint());
for (i, o) in d.outputs.iter().enumerate() { for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone()); witness.insert(*o, res[i].clone());
} }
} }
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &Vec<U>) -> Result<(), Error> { fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
if program.main.arguments.len() == inputs.len() { if program.main.arguments.len() == inputs.len() {
Ok(()) Ok(())
} else { } else {
@ -134,26 +137,136 @@ impl Interpreter {
}) })
} }
} }
pub fn execute_solver<T: Field>(
&self,
solver: &Solver,
inputs: &[T],
) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert!(inputs.len() == expected_input_count);
let res = match solver {
Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()],
false => vec![
T::one(),
T::one().checked_div(&inputs[0]).unwrap_or_else(T::one),
],
},
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
} else {
res.push(T::zero());
}
}
res
}
Solver::Xor => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - T::from(2) * x * y]
}
Solver::Or => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - x * y]
}
// res = b * c - (2b * c - b - c) * (a)
Solver::ShaAndXorAndXorAnd => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
}
// res = a(b - c) + c
Solver::ShaCh => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![a * (b - c.clone()) + c]
}
Solver::Div => vec![inputs[0]
.clone()
.checked_div(&inputs[1])
.unwrap_or_else(T::one)],
Solver::EuclideanDiv => {
use num::CheckedDiv;
let n = inputs[0].clone().to_biguint();
let d = inputs[1].clone().to_biguint();
let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into());
let r = n - d * &q;
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
}
#[cfg(feature = "bellman")]
Solver::Sha256Round => {
use zokrates_field::Bn128Field;
assert_eq!(T::id(), Bn128Field::id());
let i = &inputs[0..512];
let h = &inputs[512..];
let to_fr = |x: &T| {
use pairing_ce::ff::{PrimeField, ScalarEngine};
let s = x.to_dec_string();
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
};
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
assert_eq!(h.len(), 256);
generate_sha256_round_witness::<Bn256>(&i, &h)
.into_iter()
.map(|x| {
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
let mut res: Vec<u8> = vec![];
x.into_repr().write_le(&mut res).unwrap();
T::from_byte_vector(res)
})
.collect()
}
};
assert_eq!(res.len(), expected_output_count);
Ok(res)
}
} }
#[derive(Debug)]
pub struct EvaluationError;
impl<T: Field> LinComb<T> { impl<T: Field> LinComb<T> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> { fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
self.0 self.0
.iter() .iter()
.map(|(var, mult)| witness.get(var).map(|v| v.clone() * mult).ok_or(())) // get each term .map(|(var, mult)| {
witness
.get(var)
.map(|v| v.clone() * mult)
.ok_or(EvaluationError)
}) // get each term
.collect::<Result<Vec<_>, _>>() // fail if any term isn't found .collect::<Result<Vec<_>, _>>() // fail if any term isn't found
.map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum
} }
fn is_assignee<U>(&self, witness: &BTreeMap<FlatVariable, U>) -> bool { fn is_assignee<U>(&self, witness: &BTreeMap<FlatVariable, U>) -> bool {
self.0.iter().count() == 1 self.0.iter().count() == 1
&& self.0.iter().next().unwrap().1 == T::from(1) && self.0.get(0).unwrap().1 == T::from(1)
&& !witness.contains_key(&self.0.iter().next().unwrap().0) && !witness.contains_key(&self.0.get(0).unwrap().0)
} }
} }
impl<T: Field> QuadComb<T> { impl<T: Field> QuadComb<T> {
pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> { pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
let left = self.left.evaluate(&witness)?; let left = self.left.evaluate(&witness)?;
let right = self.right.evaluate(&witness)?; let right = self.right.evaluate(&witness)?;
Ok(left * right) Ok(left * right)
@ -192,3 +305,83 @@ impl fmt::Debug for Error {
write!(f, "{}", self) write!(f, "{}", self)
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
#[test]
fn bits_of_one() {
let inputs = vec![Bn128Field::from(1)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for i in 0..253 {
assert_eq!(res[i], Bn128Field::from(0));
}
}
#[test]
fn bits_of_42() {
let inputs = vec![Bn128Field::from(42)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
assert_eq!(res[250], Bn128Field::from(1));
assert_eq!(res[249], Bn128Field::from(0));
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
}

View file

@ -3,6 +3,7 @@ use crate::flat_absy::FlatVariable;
use crate::solvers::Solver; use crate::solvers::Solver;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
use std::hash::Hash;
use zokrates_field::Field; use zokrates_field::Field;
mod expression; mod expression;
@ -19,26 +20,12 @@ pub use self::serialize::ProgEnum;
pub use self::interpreter::{Error, ExecutionResult, Interpreter}; pub use self::interpreter::{Error, ExecutionResult, Interpreter};
pub use self::witness::Witness; pub use self::witness::Witness;
#[derive(Debug, Serialize, Deserialize, Clone, Hash)] #[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
pub enum Statement<T> { pub enum Statement<T> {
Constraint(QuadComb<T>, LinComb<T>), Constraint(QuadComb<T>, LinComb<T>),
Directive(Directive<T>), Directive(Directive<T>),
} }
impl<T: Field> PartialEq for Statement<T> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Statement::Constraint(l1, r1), Statement::Constraint(l2, r2)) => {
l1.eq(l2) && r1.eq(r2)
}
(Statement::Directive(d1), Statement::Directive(d2)) => d1.eq(d2),
_ => false,
}
}
}
impl<T: Field> Eq for Statement<T> {}
impl<T: Field> Statement<T> { impl<T: Field> Statement<T> {
pub fn definition<U: Into<QuadComb<T>>>(v: FlatVariable, e: U) -> Self { pub fn definition<U: Into<QuadComb<T>>>(v: FlatVariable, e: U) -> Self {
Statement::Constraint(e.into(), v.into()) Statement::Constraint(e.into(), v.into())
@ -49,23 +36,13 @@ impl<T: Field> Statement<T> {
} }
} }
#[derive(Clone, Debug, Serialize, Deserialize, Hash)] #[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Directive<T> { pub struct Directive<T> {
pub inputs: Vec<QuadComb<T>>, pub inputs: Vec<QuadComb<T>>,
pub outputs: Vec<FlatVariable>, pub outputs: Vec<FlatVariable>,
pub solver: Solver, pub solver: Solver,
} }
impl<T: Field> PartialEq for Directive<T> {
fn eq(&self, other: &Self) -> bool {
self.inputs.eq(&other.inputs)
&& self.outputs.eq(&other.outputs)
&& self.solver.eq(&other.solver)
}
}
impl<T: Field> Eq for Directive<T> {}
impl<T: Field> fmt::Display for Directive<T> { impl<T: Field> fmt::Display for Directive<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
@ -95,7 +72,7 @@ impl<T: Field> fmt::Display for Statement<T> {
} }
} }
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
pub struct Function<T> { pub struct Function<T> {
pub id: String, pub id: String,
pub statements: Vec<Statement<T>>, pub statements: Vec<Statement<T>>,
@ -103,15 +80,6 @@ pub struct Function<T> {
pub returns: Vec<FlatVariable>, pub returns: Vec<FlatVariable>,
} }
impl<T: Field> PartialEq for Function<T> {
fn eq(&self, other: &Self) -> bool {
self.id.eq(&other.id)
&& self.statements.eq(&other.statements)
&& self.arguments.eq(&other.arguments)
&& self.returns.eq(&other.returns)
}
}
impl<T: Field> fmt::Display for Function<T> { impl<T: Field> fmt::Display for Function<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(
@ -138,27 +106,18 @@ impl<T: Field> fmt::Display for Function<T> {
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct Prog<T> { pub struct Prog<T> {
pub main: Function<T>, pub main: Function<T>,
pub private: Vec<bool>, pub private: Vec<bool>,
} }
impl<T: Field> PartialEq for Prog<T> {
fn eq(&self, other: &Self) -> bool {
self.main.eq(&other.main) && self.private.eq(&other.private)
}
}
impl<T: Field> Prog<T> { impl<T: Field> Prog<T> {
pub fn constraint_count(&self) -> usize { pub fn constraint_count(&self) -> usize {
self.main self.main
.statements .statements
.iter() .iter()
.filter(|s| match s { .filter(|s| matches!(s, Statement::Constraint(..)))
Statement::Constraint(..) => true,
_ => false,
})
.count() .count()
} }

View file

@ -16,9 +16,9 @@ pub enum ProgEnum {
impl<T: Field> Prog<T> { impl<T: Field> Prog<T> {
pub fn serialize<W: Write>(&self, mut w: W) { pub fn serialize<W: Write>(&self, mut w: W) {
w.write(ZOKRATES_MAGIC).unwrap(); w.write_all(ZOKRATES_MAGIC).unwrap();
w.write(ZOKRATES_VERSION_1).unwrap(); w.write_all(ZOKRATES_VERSION_1).unwrap();
w.write(&T::id()).unwrap(); w.write_all(&T::id()).unwrap();
serialize_into(&mut w, self, Infinite).unwrap(); serialize_into(&mut w, self, Infinite).unwrap();
} }

View file

@ -19,7 +19,7 @@ impl fmt::Display for Error {
} }
} }
pub fn process_macros<'ast, T: Field>(file: File<'ast>) -> Result<File<'ast>, Error> { pub fn process_macros<T: Field>(file: File) -> Result<File, Error> {
match &file.pragma { match &file.pragma {
Some(pragma) => { Some(pragma) => {
if T::name() != pragma.curve.name { if T::name() != pragma.curve.name {

View file

@ -0,0 +1,10 @@
use crate::ir::{folder::Folder, LinComb};
use zokrates_field::Field;
pub struct Canonicalizer;
impl<T: Field> Folder<T> for Canonicalizer {
fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> {
l.into_canonical().into()
}
}

View file

@ -12,10 +12,10 @@
use crate::flat_absy::flat_variable::FlatVariable; use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::folder::*; use crate::ir::folder::*;
use crate::ir::*; use crate::ir::*;
use crate::optimizer::canonicalizer::Canonicalizer;
use crate::solvers::Solver; use crate::solvers::Solver;
use std::collections::hash_map::{Entry, HashMap}; use std::collections::hash_map::{Entry, HashMap};
use zokrates_field::Field; use zokrates_field::Field;
#[derive(Debug)] #[derive(Debug)]
pub struct DirectiveOptimizer<T: Field> { pub struct DirectiveOptimizer<T: Field> {
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>, calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
@ -37,6 +37,23 @@ impl<T: Field> DirectiveOptimizer<T> {
} }
impl<T: Field> Folder<T> for DirectiveOptimizer<T> { impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
// in order to correcty identify duplicates, we need to first canonicalize the statements
let mut canonicalizer = Canonicalizer;
let f = Function {
statements: f
.statements
.into_iter()
.flat_map(|s| canonicalizer.fold_statement(s))
.collect(),
..f
};
fold_function(self, f)
}
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
match s { match s {
Statement::Directive(d) => { Statement::Directive(d) => {
@ -49,7 +66,7 @@ impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
} }
Entry::Occupied(e) => { Entry::Occupied(e) => {
self.substitution self.substitution
.extend(d.outputs.into_iter().zip(e.get().into_iter().cloned())); .extend(d.outputs.into_iter().zip(e.get().iter().cloned()));
vec![] vec![]
} }
} }

View file

@ -1,7 +1,8 @@
//! Module containing the `DuplicateOptimizer` to remove duplicate constraints //! Module containing the `DuplicateOptimizer` to remove duplicate constraints
use crate::ir::folder::Folder; use crate::ir::folder::*;
use crate::ir::*; use crate::ir::*;
use crate::optimizer::canonicalizer::Canonicalizer;
use std::collections::{hash_map::DefaultHasher, HashSet}; use std::collections::{hash_map::DefaultHasher, HashSet};
use zokrates_field::Field; use zokrates_field::Field;
@ -33,6 +34,22 @@ impl DuplicateOptimizer {
} }
impl<T: Field> Folder<T> for DuplicateOptimizer { impl<T: Field> Folder<T> for DuplicateOptimizer {
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
// in order to correcty identify duplicates, we need to first canonicalize the statements
let mut canonicalizer = Canonicalizer;
let f = Function {
statements: f
.statements
.into_iter()
.flat_map(|s| canonicalizer.fold_statement(s))
.collect(),
..f
};
fold_function(self, f)
}
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
let hashed = hash(&s); let hashed = hash(&s);
let result = match self.seen.get(&hashed) { let result = match self.seen.get(&hashed) {
@ -120,7 +137,7 @@ mod tests {
main: Function { main: Function {
id: "main".to_string(), id: "main".to_string(),
statements: vec![ statements: vec![
constraint.clone(), constraint,
Statement::Constraint( Statement::Constraint(
QuadComb::from_linear_combinations( QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(42)), LinComb::summand(3, FlatVariable::new(42)),

View file

@ -4,6 +4,7 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr> //! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018 //! @date 2018
mod canonicalizer;
mod directive; mod directive;
mod duplicate; mod duplicate;
mod redefinition; mod redefinition;
@ -26,7 +27,6 @@ impl<T: Field> Prog<T> {
// // deduplicate directives which take the same input // // deduplicate directives which take the same input
let r = DirectiveOptimizer::optimize(r); let r = DirectiveOptimizer::optimize(r);
// remove duplicate constraints // remove duplicate constraints
let r = DuplicateOptimizer::optimize(r); DuplicateOptimizer::optimize(r)
r
} }
} }

View file

@ -40,7 +40,6 @@ use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::folder::{fold_function, Folder}; use crate::ir::folder::{fold_function, Folder};
use crate::ir::LinComb; use crate::ir::LinComb;
use crate::ir::*; use crate::ir::*;
use crate::solvers::Executable;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use zokrates_field::Field; use zokrates_field::Field;
@ -53,7 +52,7 @@ pub struct RedefinitionOptimizer<T: Field> {
} }
impl<T: Field> RedefinitionOptimizer<T> { impl<T: Field> RedefinitionOptimizer<T> {
fn new() -> RedefinitionOptimizer<T> { fn new() -> Self {
RedefinitionOptimizer { RedefinitionOptimizer {
substitution: HashMap::new(), substitution: HashMap::new(),
ignore: HashSet::new(), ignore: HashSet::new(),
@ -72,84 +71,77 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
let quad = self.fold_quadratic_combination(quad); let quad = self.fold_quadratic_combination(quad);
let lin = self.fold_linear_combination(lin); let lin = self.fold_linear_combination(lin);
let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() { if lin.is_zero() {
// if the right side is a single variable return vec![Statement::Constraint(quad, lin)];
Some((variable, coefficient)) => { }
match self.ignore.contains(&variable) {
// if the variable isn't tagged as ignored let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0)
false => match self.substitution.get(&variable) { || self.substitution.contains_key(&lin.0[0].0)
// if the variable is already defined {
Some(_) => (true, None, None), true => (Some(Statement::Constraint(quad, lin)), None, None),
// if the variable is not defined yet false => match lin.try_summand() {
None => match quad.try_linear() { // if the right side is a single variable
// if the left side is linear Ok((variable, coefficient)) => match quad.try_linear() {
Some(l) => (false, Some((variable, l / &coefficient)), None), // if the left side is linear
// if the left side isn't linear Ok(l) => (None, Some((variable, l / &coefficient)), None),
None => (true, None, Some(variable)), // if the left side isn't linear
}, Err(quad) => (
}, Some(Statement::Constraint(
true => (true, None, None), quad,
} LinComb::summand(coefficient, variable),
} )),
None => (true, None, None), None,
Some(variable),
),
},
Err(l) => (Some(Statement::Constraint(quad, l)), None, None),
},
}; };
// insert into the ignored set // insert into the ignored set
match to_ignore { if let Some(v) = to_ignore {
Some(v) => { self.ignore.insert(v);
self.ignore.insert(v);
}
None => {}
} }
// insert into the substitution map // insert into the substitution map
match to_insert { if let Some((k, v)) = to_insert {
Some((k, v)) => { self.substitution.insert(k, v.into_canonical());
self.substitution.insert(k, v.into_canonical());
}
None => {}
}; };
// decide whether the constraint should be kept // decide whether the constraint should be kept
match keep_constraint { match constraint {
false => vec![], Some(c) => vec![c],
true => vec![Statement::Constraint(quad, lin)], _ => vec![],
} }
} }
Statement::Directive(d) => { Statement::Directive(d) => {
let d = self.fold_directive(d); let d = self.fold_directive(d);
// check if the inputs are constants, ie reduce to the form `coeff * ~one` // check if the inputs are constants, ie reduce to the form `coeff * ~one`
let inputs = d let inputs: Vec<_> = d
.inputs .inputs
.into_iter() .into_iter()
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1` // we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
.map(|i| i.reduce()) .map(|i| i.reduce())
.map(|q| match q.try_linear() { .map(|q| {
Some(l) => match l.0.len() { match q
// 0 is constant and can be represented by an empty lincomb .try_linear()
0 => Ok(T::from(0)), .map(|l| l.try_constant().map_err(|l| l.into()))
_ => l {
// try to match to a single summand `coeff * v` Ok(r) => r,
.try_summand() Err(e) => Err(e),
.map(|(variable, coefficient)| match variable { }
// v must be ~one
v if v == FlatVariable::one() => Ok(coefficient),
_ => Err(LinComb::summand(coefficient, variable).into()),
})
.unwrap_or(Err(l.into())),
},
None => Err(q),
}) })
.collect::<Vec<Result<T, QuadComb<T>>>>(); .collect::<Vec<Result<T, QuadComb<T>>>>();
match inputs.iter().all(|r| r.is_ok()) { match inputs.iter().all(|i| i.is_ok()) {
true => { true => {
// unwrap inputs to their constant value // unwrap inputs to their constant value
let inputs = inputs.into_iter().map(|i| i.unwrap()).collect(); let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver // run the solver
let outputs = d.solver.execute(&inputs).unwrap(); let outputs = Interpreter::default()
.execute_solver(&d.solver, &inputs)
.unwrap();
assert_eq!(outputs.len(), d.outputs.len()); assert_eq!(outputs.len(), d.outputs.len());
// insert the results in the substitution // insert the results in the substitution
@ -160,8 +152,8 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
vec![] vec![]
} }
false => { false => {
// reconstruct the input expressions //reconstruct the input expressions
let inputs = inputs let inputs: Vec<_> = inputs
.into_iter() .into_iter()
.map(|i| { .map(|i| {
i.map(|v| LinComb::summand(v, FlatVariable::one()).into()) i.map(|v| LinComb::summand(v, FlatVariable::one()).into())
@ -183,8 +175,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
match lc match lc
.0 .0
.iter() .iter()
.find(|(variable, _)| self.substitution.get(&variable).is_some()) .any(|(variable, _)| self.substitution.get(&variable).is_some())
.is_some()
{ {
true => true =>
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is // for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
@ -194,7 +185,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
self.substitution self.substitution
.get(&variable) .get(&variable)
.map(|l| LinComb::from(l.clone()) * &coefficient) .map(|l| LinComb::from(l.clone()) * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable)) .unwrap_or_else(|| LinComb::summand(coefficient, variable))
}) })
.fold(LinComb::zero(), |acc, x| acc + x) .fold(LinComb::zero(), |acc, x| acc + x)
} }
@ -209,9 +200,6 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
} }
fn fold_function(&mut self, fun: Function<T>) -> Function<T> { fn fold_function(&mut self, fun: Function<T>) -> Function<T> {
self.substitution.drain();
self.ignore.drain();
// to prevent the optimiser from replacing outputs, add them to the ignored set // to prevent the optimiser from replacing outputs, add them to the ignored set
self.ignore.extend(fun.returns.iter().cloned()); self.ignore.extend(fun.returns.iter().cloned());
@ -378,7 +366,7 @@ mod tests {
// -> // ->
// def main(x, y) -> (1): // def main(x, y) -> (1):
// 6*x + 6*y == 6*x + 6*y // will be eliminated as a tautology // 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology
// return 6*x + 6*y // return 6*x + 6*y
let x = FlatVariable::new(0); let x = FlatVariable::new(0);
@ -412,7 +400,15 @@ mod tests {
LinComb::summand(6, x) + LinComb::summand(6, y), LinComb::summand(6, x) + LinComb::summand(6, y),
LinComb::summand(6, x) + LinComb::summand(6, y), LinComb::summand(6, x) + LinComb::summand(6, y),
), ),
Statement::definition(r, LinComb::summand(6, x) + LinComb::summand(6, y)), Statement::definition(
r,
LinComb::summand(1, x)
+ LinComb::summand(1, y)
+ LinComb::summand(2, x)
+ LinComb::summand(2, y)
+ LinComb::summand(3, x)
+ LinComb::summand(3, y),
),
], ],
returns: vec![r], returns: vec![r],
}; };

View file

@ -25,17 +25,16 @@ impl TautologyOptimizer {
impl<T: Field> Folder<T> for TautologyOptimizer { impl<T: Field> Folder<T> for TautologyOptimizer {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> { fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
match s { match s {
Statement::Constraint(quad, lin) => { Statement::Constraint(quad, lin) => match quad.try_linear() {
match quad.try_linear() { Ok(l) => {
Some(l) => { if l == lin {
if l == lin { vec![]
return vec![]; } else {
} vec![Statement::Constraint(l.into(), lin)]
} }
None => {}
} }
vec![Statement::Constraint(quad, lin)] Err(quad) => vec![Statement::Constraint(quad, lin)],
} },
_ => fold_statement(self, s), _ => fold_statement(self, s),
} }
} }

View file

@ -78,7 +78,7 @@ impl<T: Field + ArkFieldExtensions + NotBw6_761Field> Backend<T, GM17> for Ark {
query: vk query: vk
.query .query
.into_iter() .into_iter()
.map(|g1| serialization::to_g1::<T>(g1)) .map(serialization::to_g1::<T>)
.collect(), .collect(),
}; };
@ -172,7 +172,7 @@ impl Backend<Bw6_761Field, GM17> for Ark {
query: vk query: vk
.query .query
.into_iter() .into_iter()
.map(|g1| serialization::to_g1::<Bw6_761Field>(g1)) .map(serialization::to_g1::<Bw6_761Field>)
.collect(), .collect(),
}; };

View file

@ -49,42 +49,33 @@ fn ark_combination<T: Field + ArkFieldExtensions>(
cs: &mut ConstraintSystem<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>, cs: &mut ConstraintSystem<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>,
symbols: &mut BTreeMap<FlatVariable, Variable>, symbols: &mut BTreeMap<FlatVariable, Variable>,
witness: &mut Witness<T>, witness: &mut Witness<T>,
) -> Result< ) -> LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr> {
LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>, l.0.into_iter()
SynthesisError, .map(|(k, v)| {
> { (
let lc = v.into_ark(),
l.0.into_iter() *symbols.entry(k).or_insert_with(|| {
.map(|(k, v)| { match k.is_output() {
( true => cs.new_input_variable(|| {
v.into_ark(), Ok(witness
symbols .0
.entry(k) .remove(&k)
.or_insert_with(|| { .ok_or(SynthesisError::AssignmentMissing)?
match k.is_output() { .into_ark())
true => cs.new_input_variable(|| { }),
Ok(witness false => cs.new_witness_variable(|| {
.0 Ok(witness
.remove(&k) .0
.ok_or(SynthesisError::AssignmentMissing)? .remove(&k)
.into_ark()) .ok_or(SynthesisError::AssignmentMissing)?
}), .into_ark())
false => cs.new_witness_variable(|| { }),
Ok(witness }
.0 .unwrap()
.remove(&k) }),
.ok_or(SynthesisError::AssignmentMissing)? )
.into_ark()) })
}), .fold(LinearCombination::zero(), |acc, e| acc + e)
}
.unwrap()
})
.clone(),
)
})
.fold(LinearCombination::zero(), |acc, e| acc + e);
Ok(lc)
} }
impl<T: Field + ArkFieldExtensions> Prog<T> { impl<T: Field + ArkFieldExtensions> Prog<T> {
@ -96,7 +87,7 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
// mapping from IR variables // mapping from IR variables
let mut symbols = BTreeMap::new(); let mut symbols = BTreeMap::new();
let mut witness = witness.unwrap_or(Witness::empty()); let mut witness = witness.unwrap_or_else(Witness::empty);
assert!(symbols.insert(FlatVariable::one(), ConstraintSystem::<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>::one()).is_none()); assert!(symbols.insert(FlatVariable::one(), ConstraintSystem::<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>::one()).is_none());
@ -127,37 +118,34 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
}), }),
} }
.unwrap(); .unwrap();
(var.clone(), wire) (*var, wire)
}), }),
); );
let main = self.main; let main = self.main;
for statement in main.statements { for statement in main.statements {
match statement { if let Statement::Constraint(quad, lin) = statement {
Statement::Constraint(quad, lin) => { let a = ark_combination(
let a = ark_combination( quad.left.clone().into_canonical(),
quad.left.clone().into_canonical(), &mut cs,
&mut cs, &mut symbols,
&mut symbols, &mut witness,
&mut witness, );
)?; let b = ark_combination(
let b = ark_combination( quad.right.clone().into_canonical(),
quad.right.clone().into_canonical(), &mut cs,
&mut cs, &mut symbols,
&mut symbols, &mut witness,
&mut witness, );
)?; let c = ark_combination(
let c = ark_combination( lin.into_canonical(),
lin.into_canonical(), &mut cs,
&mut cs, &mut symbols,
&mut symbols, &mut witness,
&mut witness, );
)?;
cs.enforce_constraint(a, b, c)?; cs.enforce_constraint(a, b, c)?;
}
_ => {}
} }
} }

View file

@ -81,7 +81,7 @@ impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
ic: vk ic: vk
.gamma_abc .gamma_abc
.into_iter() .into_iter()
.map(|g1| serialization::to_g1::<T>(g1)) .map(serialization::to_g1::<T>)
.collect(), .collect(),
}; };

View file

@ -51,34 +51,31 @@ fn bellman_combination<T: BellmanFieldExtensions, CS: ConstraintSystem<T::Bellma
.map(|(k, v)| { .map(|(k, v)| {
( (
v.into_bellman(), v.into_bellman(),
symbols *symbols.entry(k).or_insert_with(|| {
.entry(k) match k.is_output() {
.or_insert_with(|| { true => cs.alloc_input(
match k.is_output() { || format!("{}", k),
true => cs.alloc_input( || {
|| format!("{}", k), Ok(witness
|| { .0
Ok(witness .remove(&k)
.0 .ok_or(SynthesisError::AssignmentMissing)?
.remove(&k) .into_bellman())
.ok_or(SynthesisError::AssignmentMissing)? },
.into_bellman()) ),
}, false => cs.alloc(
), || format!("{}", k),
false => cs.alloc( || {
|| format!("{}", k), Ok(witness
|| { .0
Ok(witness .remove(&k)
.0 .ok_or(SynthesisError::AssignmentMissing)?
.remove(&k) .into_bellman())
.ok_or(SynthesisError::AssignmentMissing)? },
.into_bellman()) ),
}, }
), .unwrap()
} }),
.unwrap()
})
.clone(),
) )
}) })
.fold(LinearCombination::zero(), |acc, e| acc + e) .fold(LinearCombination::zero(), |acc, e| acc + e)
@ -93,7 +90,7 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
// mapping from IR variables // mapping from IR variables
let mut symbols = BTreeMap::new(); let mut symbols = BTreeMap::new();
let mut witness = witness.unwrap_or(Witness::empty()); let mut witness = witness.unwrap_or_else(Witness::empty);
assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none()); assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none());
@ -127,33 +124,29 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
), ),
} }
.unwrap(); .unwrap();
(var.clone(), wire) (*var, wire)
}), }),
); );
let main = self.main; let main = self.main;
for statement in main.statements { for statement in main.statements {
match statement { if let Statement::Constraint(quad, lin) = statement {
Statement::Constraint(quad, lin) => { let a = &bellman_combination(
let a = &bellman_combination( quad.left.into_canonical(),
quad.left.into_canonical(), cs,
cs, &mut symbols,
&mut symbols, &mut witness,
&mut witness, );
); let b = &bellman_combination(
let b = &bellman_combination( quad.right.into_canonical(),
quad.right.into_canonical(), cs,
cs, &mut symbols,
&mut symbols, &mut witness,
&mut witness, );
); let c = &bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
let c =
&bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c); cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c);
}
_ => {}
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,5 @@
#[cfg(feature = "bellman")]
use pairing_ce::bn256::Bn256;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[cfg(feature = "bellman")]
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::{Bn128Field, Field};
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver { pub enum Solver {
@ -48,168 +43,3 @@ impl Solver {
Solver::Bits(width) Solver::Bits(width)
} }
} }
pub trait Executable<T> {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String>;
}
impl<T: Field> Executable<T> for Solver {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = self.get_signature();
assert_eq!(inputs.len(), expected_input_count);
let res = match self {
Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()],
false => vec![
T::one(),
T::one().checked_div(&inputs[0]).unwrap_or(T::one()),
],
},
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
} else {
res.push(T::zero());
}
}
res
}
Solver::Xor => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - T::from(2) * x * y]
}
Solver::Or => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - x * y]
}
// res = b * c - (2b * c - b - c) * (a)
Solver::ShaAndXorAndXorAnd => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
}
// res = a(b - c) + c
Solver::ShaCh => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![a * (b - c.clone()) + c]
}
Solver::Div => vec![inputs[0]
.clone()
.checked_div(&inputs[1])
.unwrap_or(T::one())],
Solver::EuclideanDiv => {
use num::CheckedDiv;
let n = inputs[0].clone().to_biguint();
let d = inputs[1].clone().to_biguint();
let q = n.checked_div(&d).unwrap_or(0u32.into());
let r = n - d * &q;
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
}
#[cfg(feature = "bellman")]
Solver::Sha256Round => {
assert_eq!(T::id(), Bn128Field::id());
let i = &inputs[0..512];
let h = &inputs[512..];
let to_fr = |x: &T| {
use pairing_ce::ff::{PrimeField, ScalarEngine};
let s = x.to_dec_string();
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
};
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
assert_eq!(h.len(), 256);
generate_sha256_round_witness::<Bn256>(&i, &h)
.into_iter()
.map(|x| {
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
let mut res: Vec<u8> = vec![];
x.into_repr().write_le(&mut res).unwrap();
T::from_byte_vector(res)
})
.collect()
}
};
assert_eq!(res.len(), expected_output_count);
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let r = cond_eq
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let r = cond_eq
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
#[test]
fn bits_of_one() {
let bits = Solver::Bits(Bn128Field::get_required_bits());
let inputs = vec![Bn128Field::from(1)];
let res = bits.execute(&inputs).unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for i in 0..253 {
assert_eq!(res[i], Bn128Field::from(0));
}
}
#[test]
fn bits_of_42() {
let bits = Solver::Bits(Bn128Field::get_required_bits());
let inputs = vec![Bn128Field::from(42)];
let res = bits.execute(&inputs).unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
assert_eq!(res[250], Bn128Field::from(1));
assert_eq!(res[249], Bn128Field::from(0));
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
}

View file

@ -0,0 +1,134 @@
use crate::typed_absy::result_folder::*;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct BoundsChecker;
pub type Error = String;
impl BoundsChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
BoundsChecker.fold_program(p)
}
pub fn check_select<'ast, T: Field, U: Select<'ast, T>>(
&mut self,
array: ArrayExpression<'ast, T>,
index: UExpression<'ast, T>,
) -> Result<U, Error> {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
match (array.get_array_type().size.as_inner(), index.as_inner()) {
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
if index >= size {
return Err(format!(
"Out of bounds access: {}[{}] but {} is of size {}",
array, index, array, size
));
}
}
_ => unreachable!(),
};
Ok(U::select(array, index))
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
type Error = Error;
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Select(box array, box index) => self
.check_select::<_, ArrayExpression<_>>(array, index)
.map(|a| a.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.fold_array_expression(array)?;
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
match (
array.get_array_type().size.as_inner(),
from.as_inner(),
to.as_inner(),
) {
(
UExpressionInner::Value(size),
UExpressionInner::Value(from),
UExpressionInner::Value(to),
) => {
if from > to {
return Err(format!(
"Slice is created from an invalid range {}..{}",
from, to
));
}
if from > size {
return Err(format!("Lower bound {} of slice {}[{}..{}] is out of bounds for array of size {}", from, array, from, to, size));
}
if to > size {
return Err(format!("Upper bound {} of slice {}[{}..{}] is out of bounds for array of size {}", to, array, from, to, size));
}
}
_ => unreachable!(),
};
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::Select(box array, box index) => self
.check_select::<_, StructExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::Select(box array, box index) => self
.check_select::<_, UExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
}

View file

@ -1,31 +1,34 @@
use crate::typed_absy; use crate::typed_absy;
use crate::typed_absy::types::{StructType, UBitwidth}; use crate::typed_absy::types::UBitwidth;
use crate::zir; use crate::zir;
use std::marker::PhantomData; use std::marker::PhantomData;
use zokrates_field::Field; use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
pub struct Flattener<T: Field> { pub struct Flattener<T: Field> {
phantom: PhantomData<T>, phantom: PhantomData<T>,
} }
fn flatten_identifier_rec<'a>( fn flatten_identifier_rec<'ast>(
id: zir::SourceIdentifier<'a>, id: zir::SourceIdentifier<'ast>,
ty: &typed_absy::Type, ty: &typed_absy::types::ConcreteType,
) -> Vec<zir::Variable<'a>> { ) -> Vec<zir::Variable<'ast>> {
match ty { match ty {
typed_absy::Type::FieldElement => vec![zir::Variable { typed_absy::ConcreteType::Int => unreachable!(),
typed_absy::ConcreteType::FieldElement => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::FieldElement, _type: zir::Type::FieldElement,
}], }],
typed_absy::Type::Boolean => vec![zir::Variable { typed_absy::types::ConcreteType::Boolean => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::Boolean, _type: zir::Type::Boolean,
}], }],
typed_absy::Type::Uint(bitwidth) => vec![zir::Variable { typed_absy::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable {
id: zir::Identifier::Source(id), id: zir::Identifier::Source(id),
_type: zir::Type::uint(bitwidth.to_usize()), _type: zir::Type::uint(bitwidth.to_usize()),
}], }],
typed_absy::Type::Array(array_type) => (0..array_type.size) typed_absy::types::ConcreteType::Array(array_type) => (0..array_type.size)
.flat_map(|i| { .flat_map(|i| {
flatten_identifier_rec( flatten_identifier_rec(
zir::SourceIdentifier::Select(box id.clone(), i), zir::SourceIdentifier::Select(box id.clone(), i),
@ -33,7 +36,7 @@ fn flatten_identifier_rec<'a>(
) )
}) })
.collect(), .collect(),
typed_absy::Type::Struct(members) => members typed_absy::types::ConcreteType::Struct(members) => members
.iter() .iter()
.flat_map(|struct_member| { .flat_map(|struct_member| {
flatten_identifier_rec( flatten_identifier_rec(
@ -57,17 +60,6 @@ impl<'ast, T: Field> Flattener<T> {
fold_program(self, p) fold_program(self, p)
} }
fn fold_module(&mut self, p: typed_absy::TypedModule<'ast, T>) -> zir::ZirModule<'ast, T> {
fold_module(self, p)
}
fn fold_function_symbol(
&mut self,
s: typed_absy::TypedFunctionSymbol<'ast, T>,
) -> zir::ZirFunctionSymbol<'ast, T> {
fold_function_symbol(self, s)
}
fn fold_function( fn fold_function(
&mut self, &mut self,
f: typed_absy::TypedFunction<'ast, T>, f: typed_absy::TypedFunction<'ast, T>,
@ -75,9 +67,12 @@ impl<'ast, T: Field> Flattener<T> {
fold_function(self, f) fold_function(self, f)
} }
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> { fn fold_declaration_parameter(
&mut self,
p: typed_absy::DeclarationParameter<'ast>,
) -> Vec<zir::Parameter<'ast>> {
let private = p.private; let private = p.private;
self.fold_variable(p.id) self.fold_variable(p.id.try_into().unwrap())
.into_iter() .into_iter()
.map(|v| zir::Parameter { id: v, private }) .map(|v| zir::Parameter { id: v, private })
.collect() .collect()
@ -87,10 +82,12 @@ impl<'ast, T: Field> Flattener<T> {
zir::SourceIdentifier::Basic(n) zir::SourceIdentifier::Basic(n)
} }
fn fold_variable(&mut self, v: typed_absy::Variable<'ast>) -> Vec<zir::Variable<'ast>> { fn fold_variable(&mut self, v: typed_absy::Variable<'ast, T>) -> Vec<zir::Variable<'ast>> {
let id = self.fold_name(v.id.clone()); let id = self.fold_name(v.id.clone());
let ty = v.get_type(); let ty = v.get_type();
let ty = typed_absy::types::ConcreteType::try_from(ty).unwrap();
flatten_identifier_rec(id, &ty) flatten_identifier_rec(id, &ty)
} }
@ -102,36 +99,34 @@ impl<'ast, T: Field> Flattener<T> {
typed_absy::TypedAssignee::Identifier(v) => self.fold_variable(v), typed_absy::TypedAssignee::Identifier(v) => self.fold_variable(v),
typed_absy::TypedAssignee::Select(box a, box i) => { typed_absy::TypedAssignee::Select(box a, box i) => {
use typed_absy::Typed; use typed_absy::Typed;
let count = match a.get_type() { let count = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap() {
typed_absy::Type::Array(array_ty) => array_ty.ty.get_primitive_count(), typed_absy::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(),
_ => unreachable!(), _ => unreachable!(),
}; };
let a = self.fold_assignee(a); let a = self.fold_assignee(a);
match i { match i.as_inner() {
typed_absy::FieldElementExpression::Number(n) => { typed_absy::UExpressionInner::Value(index) => {
let index = n.to_dec_string().parse::<usize>().unwrap(); a[*index as usize * count..(*index as usize + 1) * count].to_vec()
a[index * count..(index + 1) * count].to_vec()
} }
i => unreachable!("index {} not allowed, should be a constant", i), i => unreachable!("index {:?} not allowed, should be a constant", i),
} }
} }
typed_absy::TypedAssignee::Member(box a, m) => { typed_absy::TypedAssignee::Member(box a, m) => {
use typed_absy::Typed; use typed_absy::Typed;
let (offset, size) = match a.get_type() { let (offset, size) = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap()
typed_absy::Type::Struct(struct_type) => { {
struct_type typed_absy::ConcreteType::Struct(struct_type) => struct_type
.members .members
.iter() .iter()
.fold((0, None), |(offset, size), member| match size { .fold((0, None), |(offset, size), member| match size {
Some(_) => (offset, size), Some(_) => (offset, size),
None => match m == member.id { None => match m == member.id {
true => (offset, Some(member.ty.get_primitive_count())), true => (offset, Some(member.ty.get_primitive_count())),
false => (offset + member.ty.get_primitive_count(), None), false => (offset + member.ty.get_primitive_count(), None),
}, },
}) }),
}
_ => unreachable!(), _ => unreachable!(),
}; };
@ -151,6 +146,16 @@ impl<'ast, T: Field> Flattener<T> {
fold_statement(self, s) fold_statement(self, s)
} }
fn fold_expression_or_spread(
&mut self,
e: typed_absy::TypedExpressionOrSpread<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
typed_absy::TypedExpressionOrSpread::Expression(e) => self.fold_expression(e),
typed_absy::TypedExpressionOrSpread::Spread(s) => self.fold_array_expression(s.array),
}
}
fn fold_expression( fn fold_expression(
&mut self, &mut self,
e: typed_absy::TypedExpression<'ast, T>, e: typed_absy::TypedExpression<'ast, T>,
@ -161,8 +166,9 @@ impl<'ast, T: Field> Flattener<T> {
} }
typed_absy::TypedExpression::Boolean(e) => vec![self.fold_boolean_expression(e).into()], typed_absy::TypedExpression::Boolean(e) => vec![self.fold_boolean_expression(e).into()],
typed_absy::TypedExpression::Uint(e) => vec![self.fold_uint_expression(e).into()], typed_absy::TypedExpression::Uint(e) => vec![self.fold_uint_expression(e).into()],
typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e).into(), typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e),
typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e).into(), typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e),
typed_absy::TypedExpression::Int(_) => unreachable!(),
} }
} }
@ -185,26 +191,20 @@ impl<'ast, T: Field> Flattener<T> {
es: typed_absy::TypedExpressionList<'ast, T>, es: typed_absy::TypedExpressionList<'ast, T>,
) -> zir::ZirExpressionList<'ast, T> { ) -> zir::ZirExpressionList<'ast, T> {
match es { match es {
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => { typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => {
zir::ZirExpressionList::FunctionCall( zir::ZirExpressionList::EmbedCall(
self.fold_function_key(id), embed,
generics,
arguments arguments
.into_iter() .into_iter()
.flat_map(|a| self.fold_expression(a)) .flat_map(|a| self.fold_expression(a))
.collect(), .collect(),
vec![],
) )
} }
_ => unreachable!("should have been inlined"),
} }
} }
fn fold_function_key(
&mut self,
k: typed_absy::types::FunctionKey<'ast>,
) -> zir::types::FunctionKey<'ast> {
k.into()
}
fn fold_field_expression( fn fold_field_expression(
&mut self, &mut self,
e: typed_absy::FieldElementExpression<'ast, T>, e: typed_absy::FieldElementExpression<'ast, T>,
@ -234,7 +234,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_array_expression_inner( fn fold_array_expression_inner(
&mut self, &mut self,
ty: &typed_absy::Type, ty: &typed_absy::types::ConcreteType,
size: usize, size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>, e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
@ -242,26 +242,13 @@ impl<'ast, T: Field> Flattener<T> {
} }
fn fold_struct_expression_inner( fn fold_struct_expression_inner(
&mut self, &mut self,
ty: &StructType, ty: &typed_absy::types::ConcreteStructType,
e: typed_absy::StructExpressionInner<'ast, T>, e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
fold_struct_expression_inner(self, ty, e) fold_struct_expression_inner(self, ty, e)
} }
} }
pub fn fold_module<'ast, T: Field>(
f: &mut Flattener<T>,
p: typed_absy::TypedModule<'ast, T>,
) -> zir::ZirModule<'ast, T> {
zir::ZirModule {
functions: p
.functions
.into_iter()
.map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
.collect(),
}
}
pub fn fold_statement<'ast, T: Field>( pub fn fold_statement<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
s: typed_absy::TypedStatement<'ast, T>, s: typed_absy::TypedStatement<'ast, T>,
@ -284,9 +271,7 @@ pub fn fold_statement<'ast, T: Field>(
} }
typed_absy::TypedStatement::Declaration(v) => { typed_absy::TypedStatement::Declaration(v) => {
let v = f.fold_variable(v); let v = f.fold_variable(v);
v.into_iter() v.into_iter().map(zir::ZirStatement::Declaration).collect()
.map(|v| zir::ZirStatement::Declaration(v))
.collect()
} }
typed_absy::TypedStatement::Assertion(e) => { typed_absy::TypedStatement::Assertion(e) => {
let e = f.fold_boolean_expression(e); let e = f.fold_boolean_expression(e);
@ -302,19 +287,23 @@ pub fn fold_statement<'ast, T: Field>(
f.fold_expression_list(elist), f.fold_expression_list(elist),
)] )]
} }
typed_absy::TypedStatement::PushCallLog(..) => vec![],
typed_absy::TypedStatement::PopCallLog => vec![],
} }
} }
pub fn fold_array_expression_inner<'ast, T: Field>( pub fn fold_array_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
t: &typed_absy::Type, ty: &typed_absy::types::ConcreteType,
size: usize, size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>, array: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
match e { match array {
typed_absy::ArrayExpressionInner::Identifier(id) => { typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables = let variables = flatten_identifier_rec(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::array(t.clone(), size)); f.fold_name(id),
&typed_absy::types::ConcreteType::array((ty.clone(), size)),
);
variables variables
.into_iter() .into_iter()
.map(|v| match v._type { .map(|v| match v._type {
@ -326,10 +315,16 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
}) })
.collect() .collect()
} }
typed_absy::ArrayExpressionInner::Value(exprs) => exprs typed_absy::ArrayExpressionInner::Value(exprs) => {
.into_iter() let exprs: Vec<_> = exprs
.flat_map(|e| f.fold_expression(e)) .into_iter()
.collect(), .flat_map(|e| f.fold_expression_or_spread(e))
.collect();
assert_eq!(exprs.len(), size * ty.get_primitive_count());
exprs
}
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(), typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
typed_absy::ArrayExpressionInner::IfElse( typed_absy::ArrayExpressionInner::IfElse(
box condition, box condition,
@ -369,40 +364,74 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count()) .map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
// we also need the size of this member // we also need the size of this member
let size = t.get_primitive_count() * size; let size = ty.get_primitive_count() * size;
s[offset..offset + size].to_vec() s[offset..offset + size].to_vec()
} }
typed_absy::ArrayExpressionInner::Select(box array, box index) => { typed_absy::ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index); let index = f.fold_uint_expression(index);
match index { match index.into_inner() {
zir::FieldElementExpression::Number(i) => { zir::UExpressionInner::Value(i) => {
let size = t.get_primitive_count() * size; let size = ty.clone().get_primitive_count() * size;
let start = i.to_dec_string().parse::<usize>().unwrap() * size; let start = i as usize * size;
let end = start + size; let end = start + size;
array[start..end].to_vec() array[start..end].to_vec()
} }
_ => unreachable!(), _ => unreachable!(),
} }
} }
typed_absy::ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = f.fold_array_expression(array);
let from = f.fold_uint_expression(from);
let to = f.fold_uint_expression(to);
match (from.into_inner(), to.into_inner()) {
(zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => {
assert_eq!(size, to.saturating_sub(from) as usize);
let element_size = ty.get_primitive_count();
let start = from as usize * element_size;
let end = to as usize * element_size;
array[start..end].to_vec()
}
_ => unreachable!(),
}
}
typed_absy::ArrayExpressionInner::Repeat(box e, box count) => {
let e = f.fold_expression(e);
let count = f.fold_uint_expression(count);
match count.into_inner() {
zir::UExpressionInner::Value(count) => {
vec![e; count as usize].into_iter().flatten().collect()
}
_ => unreachable!(),
}
}
} }
} }
pub fn fold_struct_expression_inner<'ast, T: Field>( pub fn fold_struct_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
t: &StructType, ty: &typed_absy::types::ConcreteStructType,
e: typed_absy::StructExpressionInner<'ast, T>, struc: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
match e { match struc {
typed_absy::StructExpressionInner::Identifier(id) => { typed_absy::StructExpressionInner::Identifier(id) => {
let variables = let variables = flatten_identifier_rec(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::struc(t.clone())); f.fold_name(id),
&typed_absy::types::ConcreteType::struc(ty.clone()),
);
variables variables
.into_iter() .into_iter()
.map(|v| match v._type { .map(|v| match v._type {
@ -457,13 +486,18 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count()) .map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
// we also need the size of this member // we also need the size of this member
let size = t let size = ty
.iter() .iter()
.find(|member| member.id == id) .find(|member| member.id == id)
.cloned()
.unwrap() .unwrap()
.ty .ty
.get_primitive_count(); .get_primitive_count();
@ -472,15 +506,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
} }
typed_absy::StructExpressionInner::Select(box array, box index) => { typed_absy::StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index); let index = f.fold_uint_expression(index);
match index { match index.into_inner() {
zir::FieldElementExpression::Number(i) => { zir::UExpressionInner::Value(i) => {
let size = t let size: usize = ty.iter().map(|m| m.ty.get_primitive_count()).sum();
.iter() let start = i as usize * size;
.map(|m| m.ty.get_primitive_count())
.fold(0, |acc, current| acc + current);
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
let end = start + size; let end = start + size;
array[start..end].to_vec() array[start..end].to_vec()
} }
@ -498,9 +529,12 @@ pub fn fold_field_expression<'ast, T: Field>(
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
typed_absy::FieldElementExpression::Identifier(id) => { typed_absy::FieldElementExpression::Identifier(id) => {
zir::FieldElementExpression::Identifier( zir::FieldElementExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::FieldElement)[0] flatten_identifier_rec(
.id f.fold_name(id),
.clone(), &typed_absy::types::ConcreteType::FieldElement,
)[0]
.id
.clone(),
) )
} }
typed_absy::FieldElementExpression::Add(box e1, box e2) => { typed_absy::FieldElementExpression::Add(box e1, box e2) => {
@ -525,7 +559,7 @@ pub fn fold_field_expression<'ast, T: Field>(
} }
typed_absy::FieldElementExpression::Pow(box e1, box e2) => { typed_absy::FieldElementExpression::Pow(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_uint_expression(e2);
zir::FieldElementExpression::Pow(box e1, box e2) zir::FieldElementExpression::Pow(box e1, box e2)
} }
typed_absy::FieldElementExpression::Neg(box e) => { typed_absy::FieldElementExpression::Neg(box e) => {
@ -556,26 +590,22 @@ pub fn fold_field_expression<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count()) .map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap() s[offset].clone().try_into().unwrap()
} }
typed_absy::FieldElementExpression::Select(box array, box index) => { typed_absy::FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index); let index = f.fold_uint_expression(index);
use std::convert::TryInto; match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(""), _ => unreachable!(""),
} }
} }
@ -589,7 +619,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
match e { match e {
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v), typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier( typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Boolean)[0] flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0]
.id .id
.clone(), .clone(),
), ),
@ -665,25 +695,45 @@ pub fn fold_boolean_expression<'ast, T: Field>(
zir::BooleanExpression::UintEq(box e1, box e2) zir::BooleanExpression::UintEq(box e1, box e2)
} }
typed_absy::BooleanExpression::Lt(box e1, box e2) => { typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Lt(box e1, box e2) zir::BooleanExpression::FieldLt(box e1, box e2)
} }
typed_absy::BooleanExpression::Le(box e1, box e2) => { typed_absy::BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Le(box e1, box e2) zir::BooleanExpression::FieldLe(box e1, box e2)
} }
typed_absy::BooleanExpression::Gt(box e1, box e2) => { typed_absy::BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Gt(box e1, box e2) zir::BooleanExpression::FieldGt(box e1, box e2)
} }
typed_absy::BooleanExpression::Ge(box e1, box e2) => { typed_absy::BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1); let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2); let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Ge(box e1, box e2) zir::BooleanExpression::FieldGe(box e1, box e2)
}
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLt(box e1, box e2)
}
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLe(box e1, box e2)
}
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGt(box e1, box e2)
}
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGe(box e1, box e2)
} }
typed_absy::BooleanExpression::Or(box e1, box e2) => { typed_absy::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1); let e1 = f.fold_boolean_expression(e1);
@ -714,25 +764,21 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count()) .map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap() s[offset].clone().try_into().unwrap()
} }
typed_absy::BooleanExpression::Select(box array, box index) => { typed_absy::BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index); let index = f.fold_uint_expression(index);
use std::convert::TryInto; match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
_ => unreachable!(), _ => unreachable!(),
} }
} }
@ -755,9 +801,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
match e { match e {
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v), typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier( typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Uint(bitwidth))[0] flatten_identifier_rec(
.id f.fold_name(id),
.clone(), &typed_absy::types::ConcreteType::Uint(bitwidth),
)[0]
.id
.clone(),
), ),
typed_absy::UExpressionInner::Add(box left, box right) => { typed_absy::UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(left); let left = f.fold_uint_expression(left);
@ -771,6 +820,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
zir::UExpressionInner::Sub(box left, box right) zir::UExpressionInner::Sub(box left, box right)
} }
typed_absy::UExpressionInner::FloorSub(..) => unreachable!(),
typed_absy::UExpressionInner::Mult(box left, box right) => { typed_absy::UExpressionInner::Mult(box left, box right) => {
let left = f.fold_uint_expression(left); let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right); let right = f.fold_uint_expression(right);
@ -827,11 +877,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
typed_absy::UExpressionInner::Neg(box e) => { typed_absy::UExpressionInner::Neg(box e) => {
let bitwidth = e.bitwidth(); let bitwidth = e.bitwidth();
f.fold_uint_expression(typed_absy::UExpression::sub( f.fold_uint_expression(typed_absy::UExpressionInner::Value(0).annotate(bitwidth) - e)
typed_absy::UExpressionInner::Value(0).annotate(bitwidth), .into_inner()
e,
))
.into_inner()
} }
typed_absy::UExpressionInner::Pos(box e) => { typed_absy::UExpressionInner::Pos(box e) => {
@ -844,16 +891,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
} }
typed_absy::UExpressionInner::Select(box array, box index) => { typed_absy::UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array); let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index); let index = f.fold_uint_expression(index);
use std::convert::TryInto; match index.into_inner() {
zir::UExpressionInner::Value(i) => {
match index { let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap();
zir::FieldElementExpression::Number(i) => {
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap();
e.into_inner() e.into_inner()
} }
_ => unreachable!(), _ => unreachable!(),
@ -867,11 +909,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
let offset: usize = members let offset: usize = members
.iter() .iter()
.take_while(|member| member.id != id) .take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count()) .map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum(); .sum();
use std::convert::TryInto;
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap(); let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
res.into_inner() res.into_inner()
@ -893,14 +937,18 @@ pub fn fold_function<'ast, T: Field>(
arguments: fun arguments: fun
.arguments .arguments
.into_iter() .into_iter()
.flat_map(|a| f.fold_parameter(a)) .flat_map(|a| f.fold_declaration_parameter(a))
.collect(), .collect(),
statements: fun statements: fun
.statements .statements
.into_iter() .into_iter()
.flat_map(|s| f.fold_statement(s)) .flat_map(|s| f.fold_statement(s))
.collect(), .collect(),
signature: fun.signature.into(), signature: typed_absy::types::ConcreteSignature::try_from(
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
)
.unwrap()
.into(),
} }
} }
@ -908,41 +956,46 @@ pub fn fold_array_expression<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
e: typed_absy::ArrayExpression<'ast, T>, e: typed_absy::ArrayExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner()) let size = match e.size().into_inner() {
typed_absy::UExpressionInner::Value(v) => v,
_ => unreachable!(),
} as usize;
f.fold_array_expression_inner(
&typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(),
size,
e.into_inner(),
)
} }
pub fn fold_struct_expression<'ast, T: Field>( pub fn fold_struct_expression<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
e: typed_absy::StructExpression<'ast, T>, e: typed_absy::StructExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> { ) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner()) f.fold_struct_expression_inner(
} &typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
pub fn fold_function_symbol<'ast, T: Field>( )
f: &mut Flattener<T>,
s: typed_absy::TypedFunctionSymbol<'ast, T>,
) -> zir::ZirFunctionSymbol<'ast, T> {
match s {
typed_absy::TypedFunctionSymbol::Here(fun) => {
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
}
typed_absy::TypedFunctionSymbol::There(key, module) => {
zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
} // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
}
} }
pub fn fold_program<'ast, T: Field>( pub fn fold_program<'ast, T: Field>(
f: &mut Flattener<T>, f: &mut Flattener<T>,
p: typed_absy::TypedProgram<'ast, T>, mut p: typed_absy::TypedProgram<'ast, T>,
) -> zir::ZirProgram<'ast, T> { ) -> zir::ZirProgram<'ast, T> {
let main_module = p.modules.remove(&p.main).unwrap();
let main_function = main_module
.functions
.into_iter()
.find(|(key, _)| key.id == "main")
.unwrap()
.1;
let main_function = match main_function {
typed_absy::TypedFunctionSymbol::Here(f) => f,
_ => unreachable!(),
};
zir::ZirProgram { zir::ZirProgram {
modules: p main: f.fold_function(main_function),
.modules
.into_iter()
.map(|(module_id, module)| (module_id, f.fold_module(module)))
.collect(),
main: p.main,
} }
} }

File diff suppressed because it is too large Load diff

View file

@ -4,69 +4,93 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr> //! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018 //! @date 2018
mod bounds_checker;
mod flat_propagation; mod flat_propagation;
mod flatten_complex_types; mod flatten_complex_types;
mod inline;
mod propagate_unroll;
mod propagation; mod propagation;
mod redefinition; mod redefinition;
mod return_binder; mod reducer;
mod uint_optimizer; mod uint_optimizer;
mod unconstrained_vars; mod unconstrained_vars;
mod unroll;
mod variable_read_remover; mod variable_read_remover;
mod variable_write_remover; mod variable_write_remover;
use self::bounds_checker::BoundsChecker;
use self::flatten_complex_types::Flattener; use self::flatten_complex_types::Flattener;
use self::inline::Inliner;
use self::propagate_unroll::PropagatedUnroller;
use self::propagation::Propagator; use self::propagation::Propagator;
use self::redefinition::RedefinitionOptimizer; use self::redefinition::RedefinitionOptimizer;
use self::return_binder::ReturnBinder; use self::reducer::reduce_program;
use self::uint_optimizer::UintOptimizer; use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector; use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_read_remover::VariableReadRemover; use self::variable_read_remover::VariableReadRemover;
use self::variable_write_remover::VariableWriteRemover; use self::variable_write_remover::VariableWriteRemover;
use crate::flat_absy::FlatProg; use crate::flat_absy::FlatProg;
use crate::ir::Prog; use crate::ir::Prog;
use crate::typed_absy::TypedProgram; use crate::typed_absy::{abi::Abi, TypedProgram};
use crate::zir::ZirProgram; use crate::zir::ZirProgram;
use std::fmt;
use zokrates_field::Field; use zokrates_field::Field;
pub trait Analyse { pub trait Analyse {
fn analyse(self) -> Self; fn analyse(self) -> Self;
} }
#[derive(Debug)]
pub enum Error {
Reducer(self::reducer::Error),
OutOfBounds(self::bounds_checker::Error),
Propagation(self::propagation::Error),
}
impl From<self::reducer::Error> for Error {
fn from(e: self::reducer::Error) -> Self {
Error::Reducer(e)
}
}
impl From<self::bounds_checker::Error> for Error {
fn from(e: bounds_checker::Error) -> Self {
Error::OutOfBounds(e)
}
}
impl From<self::propagation::Error> for Error {
fn from(e: propagation::Error) -> Self {
Error::Propagation(e)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Reducer(e) => write!(f, "{}", e),
Error::OutOfBounds(e) => write!(f, "{}", e),
Error::Propagation(e) => write!(f, "{}", e),
}
}
}
impl<'ast, T: Field> TypedProgram<'ast, T> { impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> ZirProgram<'ast, T> { pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
// propagated unrolling let r = reduce_program(self).map_err(Error::from)?;
let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!("{}", e));
// return binding let abi = r.abi();
let r = ReturnBinder::bind(r);
// inline
let r = Inliner::inline(r);
// propagate // propagate
let r = Propagator::propagate(r); let r = Propagator::propagate(r).map_err(Error::from)?;
// optimize redefinitions // optimize redefinitions
let r = RedefinitionOptimizer::optimize(r); let r = RedefinitionOptimizer::optimize(r);
// remove assignment to variable index // remove assignment to variable index
let r = VariableWriteRemover::apply(r); let r = VariableWriteRemover::apply(r);
// remove variable access to complex types // remove variable access to complex types
let r = VariableReadRemover::apply(r); let r = VariableReadRemover::apply(r);
// check array accesses are in bounds
let r = BoundsChecker::check(r).map_err(Error::from)?;
// convert to zir, removing complex types // convert to zir, removing complex types
let zir = Flattener::flatten(r); let zir = Flattener::flatten(r);
// optimize uint expressions // optimize uint expressions
let zir = UintOptimizer::optimize(zir); let zir = UintOptimizer::optimize(zir);
zir Ok((zir, abi))
} }
} }
@ -78,7 +102,6 @@ impl<T: Field> Analyse for FlatProg<T> {
impl<T: Field> Analyse for Prog<T> { impl<T: Field> Analyse for Prog<T> {
fn analyse(self) -> Self { fn analyse(self) -> Self {
let r = UnconstrainedVariableDetector::detect(self); UnconstrainedVariableDetector::detect(self)
r
} }
} }

View file

@ -1,236 +0,0 @@
//! Module containing iterative unrolling in order to unroll nested loops with variable bounds
//!
//! For example:
//! ```zokrates
//! for field i in 0..5 do
//! for field j in i..5 do
//! //
//! endfor
//! endfor
//! ```
//!
//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop
//!
//! This module does exactly that:
//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant
//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet*
//! - unroll again, this time the 5 inner loops all have constant bounds
//!
//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does
//! not make progress anymore.
use crate::static_analysis::propagation::Propagator;
use crate::static_analysis::unroll::{Output, Unroller};
use crate::typed_absy::TypedProgram;
use zokrates_field::Field;
pub struct PropagatedUnroller;
impl PropagatedUnroller {
pub fn unroll<'ast, T: Field>(
p: TypedProgram<'ast, T>,
) -> Result<TypedProgram<'ast, T>, &'static str> {
let mut blocked_at = None;
// unroll a first time, retrieving whether the unroll is complete
let mut unrolled = Unroller::unroll(p);
loop {
// conditions to exit the loop
unrolled = match unrolled {
Output::Complete(p) => return Ok(p),
Output::Incomplete(next, index) => {
if Some(index) == blocked_at {
return Err("Loop unrolling failed. This happened because a loop bound is not constant");
} else {
// update the index where we blocked
blocked_at = Some(index);
// propagate
let propagated = Propagator::propagate_verbose(next);
// unroll
Unroller::unroll(propagated)
}
}
};
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::{FunctionKey, Signature};
use crate::typed_absy::*;
use zokrates_field::Bn128Field;
#[test]
fn detect_non_constant_bound() {
let loops = vec![TypedStatement::For(
Variable::field_element("i"),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![],
)];
let statements = loops;
let p = TypedProgram {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements,
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".into(),
};
assert!(PropagatedUnroller::unroll(p).is_err());
}
#[test]
fn for_loop() {
// for field i in 0..2
// for field j in i..2
// field foo = i + j
// should be unrolled to
// i_0 = 0
// j_0 = 0
// foo_0 = i_0 + j_0
// j_1 = 1
// foo_1 = i_0 + j_1
// i_1 = 1
// j_2 = 1
// foo_2 = i_1 + j_1
let s = TypedStatement::For(
Variable::field_element("i"),
FieldElementExpression::Number(Bn128Field::from(0)),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![TypedStatement::For(
Variable::field_element("j"),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![
TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("foo")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Identifier("j".into()),
)
.into(),
),
],
)],
);
let expected_statements = vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(0),
)),
FieldElementExpression::Number(Bn128Field::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(0),
)),
FieldElementExpression::Number(Bn128Field::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(0),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(0)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(1),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(1),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(1)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(1),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(2),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(2),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(1)),
box FieldElementExpression::Identifier(Identifier::from("j").version(2)),
)
.into(),
),
];
let p = TypedProgram {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements: vec![s],
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".into(),
};
let statements = match PropagatedUnroller::unroll(p).unwrap().modules
[std::path::Path::new("main")]
.functions[&FunctionKey::with_id("main")]
.clone()
{
TypedFunctionSymbol::Here(f) => f.statements,
_ => unreachable!(),
};
assert_eq!(statements, expected_statements);
}
}

File diff suppressed because it is too large Load diff

View file

@ -68,6 +68,6 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast> {
} }
fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> { fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> {
self.identifiers.get(&s).map(|r| r.clone()).unwrap_or(s) self.identifiers.get(&s).cloned().unwrap_or(s)
} }
} }

View file

@ -0,0 +1,235 @@
// The inlining phase takes a call site (fun::<gen>(args)) and inlines it:
// Given:
// ```
// def foo<n>(field a) -> field:
// a = a
// n = n
// return a
// ```
//
// The call site
// ```
// foo::<42>(x)
// ```
//
// Becomes
// ```
// # Call foo::<42> with a_0 := x
// n_0 = 42
// a_1 = a_0
// n_1 = n_0
// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1
// Notes:
// - The body of the function is in SSA form
// - The return value(s) are assigned to internal variables
use crate::embed::FlatEmbed;
use crate::static_analysis::reducer::Output;
use crate::static_analysis::reducer::ShallowTransformer;
use crate::static_analysis::reducer::Versions;
use crate::typed_absy::types::ConcreteGenericsAssignment;
use crate::typed_absy::CoreIdentifier;
use crate::typed_absy::Identifier;
use crate::typed_absy::TypedAssignee;
use crate::typed_absy::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
UExpressionInner, Variable,
};
use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>),
Flat(
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
NonConstant(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
}
fn get_canonical_function<'ast, T: Field>(
function_key: DeclarationFunctionKey<'ast>,
program: &TypedProgram<'ast, T>,
) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) {
match program
.modules
.get(&function_key.module)
.unwrap()
.functions
.iter()
.find(|(key, _)| function_key == **key)
.unwrap()
{
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program),
(key, s) => (key.clone(), s.clone()),
}
}
type InlineResult<'ast, T> = Result<
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
InlineError<'ast, T>,
>;
pub fn inline_call<'a, 'ast, T: Field>(
k: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
program: &TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> InlineResult<'ast, T> {
use std::convert::TryFrom;
use crate::typed_absy::Typed;
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
.iter()
.map(|g| {
g.as_ref()
.map(|g| match g.as_inner() {
UExpressionInner::Value(v) => Ok(*v as u32),
_ => Err(()),
})
.transpose()
})
.collect::<Result<_, _>>()
.map_err(|_| {
InlineError::NonConstant(
k.clone(),
generics.clone(),
arguments.clone(),
output_types.clone(),
)
})?;
// we infer a signature based on inputs and outputs
// this is where we could handle explicit annotations
let inferred_signature = Signature::new()
.generics(generics.clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.outputs(output_types.clone());
// we try to get concrete values for the whole signature. if this fails we should propagate again
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
Ok(s) => s,
Err(_) => {
return Err(InlineError::NonConstant(
k,
generics,
arguments,
output_types,
));
}
};
let (decl_key, symbol) = get_canonical_function(k.clone(), program);
// get an assignment of generics for this call site
let assignment: ConcreteGenericsAssignment<'ast> = k
.signature
.specialize(generics_values, &inferred_signature)
.map_err(|_| {
InlineError::Generic(
k.clone(),
ConcreteFunctionKey {
module: decl_key.module.clone(),
id: decl_key.id,
signature: inferred_signature.clone(),
},
)
})?;
let f = match symbol {
TypedFunctionSymbol::Here(f) => Ok(f),
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
e,
e.generics(&assignment),
arguments.clone(),
output_types,
)),
_ => unreachable!(),
}?;
assert_eq!(f.arguments.len(), arguments.len());
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
Output::Complete(v) => (v, None),
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
};
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone());
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
.arguments
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
.zip(arguments.clone())
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
.statements
.into_iter()
.partition(|s| !matches!(s, TypedStatement::Return(..)));
assert_eq!(returns.len(), 1);
let returns = match returns.pop().unwrap() {
TypedStatement::Return(e) => e,
_ => unreachable!(),
};
let res: Vec<ConcreteVariable<'ast>> = inferred_signature
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
ConcreteVariable::with_id_and_type(
Identifier::from(CoreIdentifier::Call(i)).version(
*versions
.entry(CoreIdentifier::Call(i).clone())
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
),
t.clone(),
)
})
.collect();
let expressions: Vec<TypedExpression<_>> = res
.iter()
.map(|v| TypedExpression::from(Variable::from(v.clone())))
.collect();
assert_eq!(res.len(), returns.len());
let output_bindings: Vec<TypedStatement<'ast, T>> = res
.into_iter()
.zip(returns)
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let pop_log = TypedStatement::PopCallLog;
let statements: Vec<_> = std::iter::once(call_log)
.chain(input_bindings)
.chain(statements)
.chain(output_bindings)
.chain(std::iter::once(pop_log))
.collect();
Ok(incomplete_data
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
.unwrap_or_else(|| Output::Complete((statements, expressions))))
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,58 +0,0 @@
use crate::typed_absy::folder::fold_statement;
use crate::typed_absy::identifier::CoreIdentifier;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct ReturnBinder;
impl ReturnBinder {
pub fn bind<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
ReturnBinder {}.fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Return(exprs) => {
let ret_identifiers: Vec<Identifier<'ast>> = (0..exprs.len())
.map(|i| CoreIdentifier::Internal("RETURN", i).into())
.collect();
let ret_expressions: Vec<TypedExpression<'ast, T>> = exprs
.iter()
.zip(ret_identifiers.iter())
.map(|(e, i)| match e.get_type() {
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
.annotate(*array_type.ty, array_type.size)
.into(),
Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone())
.annotate(struct_type)
.into(),
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
.annotate(bitwidth)
.into(),
})
.collect();
exprs
.into_iter()
.zip(ret_identifiers.iter())
.map(|(e, i)| {
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
i.clone(),
e.get_type(),
)),
e,
)
})
.chain(std::iter::once(TypedStatement::Return(ret_expressions)))
.collect()
}
s => fold_statement(self, s),
}
}
}

Some files were not shown because too many files have changed in this diff Show more