1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

merge dev, implement inference for neg and pos, fix conclicts

This commit is contained in:
schaeff 2021-04-01 17:34:05 +02:00
commit 9d13b4129d
203 changed files with 12873 additions and 7237 deletions

View file

@ -4,6 +4,7 @@ jobs:
build:
docker:
- image: zokrates/env:latest
resource_class: large
steps:
- checkout
- run:
@ -28,6 +29,7 @@ jobs:
test:
docker:
- image: zokrates/env:latest
resource_class: large
steps:
- checkout
- run:
@ -42,6 +44,9 @@ jobs:
- run:
name: Check format
command: cargo fmt --all -- --check
- run:
name: Run clippy
command: cargo clippy
- run:
name: Build
command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh
@ -80,6 +85,7 @@ jobs:
docker:
- image: zokrates/env:latest
- image: trufflesuite/ganache-cli:next
resource_class: large
steps:
- checkout
- run:

74
Cargo.lock generated
View file

@ -133,7 +133,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c"
dependencies = [
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -145,7 +145,7 @@ dependencies = [
"num-bigint 0.4.0",
"num-traits 0.2.14",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -205,7 +205,7 @@ checksum = "5ac3d78c750b01f5df5b2e76d106ed31487a93b3868f14a7f0eb3a74f45e1d8a"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -653,7 +653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d"
dependencies = [
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -664,7 +664,7 @@ checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -765,7 +765,7 @@ checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
"synstructure",
]
@ -809,7 +809,7 @@ dependencies = [
"num-traits 0.2.14",
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -1059,9 +1059,9 @@ dependencies = [
[[package]]
name = "js-sys"
version = "0.3.49"
version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc15e39392125075f60c95ba416f5381ff6c3a948ff02ab12464715adf56c821"
checksum = "2d99f9e3e84b8f67f846ef5b4cbbc3b1c29f6c759fcbce6f01aa0e73d932a24c"
dependencies = [
"wasm-bindgen",
]
@ -1074,9 +1074,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.91"
version = "0.2.92"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7"
checksum = "56d855069fafbb9b344c0f962150cd2c1187975cb1c22c1522c240d8c4986714"
[[package]]
name = "libgit2-sys"
@ -1370,7 +1370,7 @@ dependencies = [
"pest_meta",
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -1773,7 +1773,7 @@ checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
]
[[package]]
@ -1866,9 +1866,9 @@ dependencies = [
[[package]]
name = "syn"
version = "1.0.64"
version = "1.0.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f"
checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
@ -1883,7 +1883,7 @@ checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
"unicode-xid 0.2.1",
]
@ -2106,9 +2106,9 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
[[package]]
name = "wasm-bindgen"
version = "0.2.72"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fe8f61dba8e5d645a4d8132dc7a0a66861ed5e1045d2c0ed940fab33bac0fbe"
checksum = "83240549659d187488f91f33c0f8547cbfef0b2088bc470c116d1d260ef623d9"
dependencies = [
"cfg-if 1.0.0",
"wasm-bindgen-macro",
@ -2116,24 +2116,24 @@ dependencies = [
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.72"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "046ceba58ff062da072c7cb4ba5b22a37f00a302483f7e2a6cdc18fedbdc1fd3"
checksum = "ae70622411ca953215ca6d06d3ebeb1e915f0f6613e3b495122878d7ebec7dae"
dependencies = [
"bumpalo",
"lazy_static",
"log",
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.22"
version = "0.4.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73157efb9af26fb564bb59a009afd1c7c334a44db171d280690d0c3faaec3468"
checksum = "81b8b767af23de6ac18bf2168b690bed2902743ddf0fb39252e36f9e2bfc63ea"
dependencies = [
"cfg-if 1.0.0",
"js-sys",
@ -2143,9 +2143,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.72"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ef9aa01d36cda046f797c57959ff5f3c615c9cc63997a8d545831ec7976819b"
checksum = "3e734d91443f177bfdb41969de821e15c516931c3c3db3d318fa1b68975d0f6f"
dependencies = [
"quote 1.0.9",
"wasm-bindgen-macro-support",
@ -2153,28 +2153,28 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.72"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96eb45c1b2ee33545a813a92dbb53856418bf7eb54ab34f7f7ff1448a5b3735d"
checksum = "d53739ff08c8a68b0fdbcd54c372b8ab800b1449ab3c9d706503bc7dd1621b2c"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.72"
version = "0.2.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7148f4696fb4960a346eaa60bbfb42a1ac4ebba21f750f75fc1375b098d5ffa"
checksum = "d9a543ae66aa233d14bb765ed9af4a33e81b8b58d1584cf1b47ff8cd0b9e4489"
[[package]]
name = "wasm-bindgen-test"
version = "0.3.22"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f002ea97b5abdb19aafd48cbb5a0a7f6931cf36ea05a0a46ccc95d9f4c2cf43"
checksum = "e972e914de63aa53bd84865e54f5c761bd274d48e5be3a6329a662c0386aa67a"
dependencies = [
"console_error_panic_hook",
"js-sys",
@ -2186,9 +2186,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-test-macro"
version = "0.3.22"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10a6c0bd3933daf64c78fc25a7452530f79fa7e21f77fa03d608d1e988a66735"
checksum = "ea6153a8f9bf24588e9f25c87223414fff124049f68d3a442a0f0eab4768a8b6"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
@ -2196,9 +2196,9 @@ dependencies = [
[[package]]
name = "web-sys"
version = "0.3.49"
version = "0.3.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fe19d70f5dacc03f6e46777213facae5ac3801575d56ca6cbd4c93dcd12310"
checksum = "a905d57e488fec8861446d3393670fb50d27a262344013181c2cdf9fff5481be"
dependencies = [
"js-sys",
"wasm-bindgen",
@ -2252,7 +2252,7 @@ checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16"
dependencies = [
"proc-macro2 1.0.24",
"quote 1.0.9",
"syn 1.0.64",
"syn 1.0.67",
"synstructure",
]

7
asserts.zok Normal file
View file

@ -0,0 +1,7 @@
def id<N>() -> u32:
return N
def main():
assert(id::<5>() == 5)
assert(id::<6>() == 6)
return

View file

@ -0,0 +1 @@
Introduce constant generics for `u32` values. Introduce literal inference

View file

@ -0,0 +1 @@
Make embed functions generic, enabling unpacking to any width at minimal cost

11
example.zok Normal file
View file

@ -0,0 +1,11 @@
def foo<N>(field[N] x) -> field[N]:
return x
def bar<N>(field[N] x) -> field[N]:
field[N] r = x
return r
def main(field[3] x) -> field[2]:
field[2] z = foo(x)[0..2]
return bar(z)

13
scripts/benchmark.sh Executable file
View file

@ -0,0 +1,13 @@
#!/bin/bash
# Usage: benchmark.sh <command>
# For MacOS: install gtime with homebrew `brew install gnu-time`
cmd=$*
format="mem=%KK rss=%MK elapsed=%E cpu=%P cpu.sys=%S inputs=%I outputs=%O"
if command -v gtime; then
gtime -f "$format" $cmd
else
/usr/bin/time -f "$format" $cmd
fi

View file

@ -17,7 +17,7 @@ impl<T: From<usize>> Encode<T> for Inputs<T> {
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::fmt;
use zokrates_core::typed_absy::{Type, UBitwidth};
use zokrates_core::typed_absy::types::{ConcreteType, UBitwidth};
use zokrates_field::Field;
@ -94,18 +94,20 @@ impl<T: Field> fmt::Display for Value<T> {
}
impl<T: Field> Value<T> {
fn check(self, ty: Type) -> Result<CheckedValue<T>, String> {
fn check(self, ty: ConcreteType) -> Result<CheckedValue<T>, String> {
match (self, ty) {
(Value::Field(f), Type::FieldElement) => Ok(CheckedValue::Field(f)),
(Value::U8(f), Type::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)),
(Value::U16(f), Type::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)),
(Value::U32(f), Type::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)),
(Value::Boolean(b), Type::Boolean) => Ok(CheckedValue::Boolean(b)),
(Value::Array(a), Type::Array(array_type)) => {
if a.len() != array_type.size {
(Value::Field(f), ConcreteType::FieldElement) => Ok(CheckedValue::Field(f)),
(Value::U8(f), ConcreteType::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)),
(Value::U16(f), ConcreteType::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)),
(Value::U32(f), ConcreteType::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)),
(Value::Boolean(b), ConcreteType::Boolean) => Ok(CheckedValue::Boolean(b)),
(Value::Array(a), ConcreteType::Array(array_type)) => {
let size = array_type.size;
if a.len() != size as usize {
Err(format!(
"Expected array of size {}, found array of size {}",
array_type.size,
size,
a.len()
))
} else {
@ -116,15 +118,16 @@ impl<T: Field> Value<T> {
Ok(CheckedValue::Array(a))
}
}
(Value::Struct(mut s), Type::Struct(members)) => {
if s.len() != members.len() {
(Value::Struct(mut s), ConcreteType::Struct(struc)) => {
if s.len() != struc.members_count() {
Err(format!(
"Expected {} member(s), found {}",
members.len(),
struc.members_count(),
s.len()
))
} else {
let s = members
let s = struc
.members
.into_iter()
.map(|member| {
s.remove(&member.id)
@ -167,7 +170,7 @@ impl<T: From<usize>> Encode<T> for CheckedValue<T> {
}
impl<T: Field> Decode<T> for CheckedValues<T> {
type Expected = Vec<Type>;
type Expected = Vec<ConcreteType>;
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
CheckedValues(
@ -185,23 +188,24 @@ impl<T: Field> Decode<T> for CheckedValues<T> {
}
impl<T: Field> Decode<T> for CheckedValue<T> {
type Expected = Type;
type Expected = ConcreteType;
fn decode(raw: Vec<T>, expected: Self::Expected) -> Self {
let mut raw = raw;
match expected {
Type::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
Type::Uint(UBitwidth::B8) => CheckedValue::U8(
ConcreteType::Int => unreachable!(),
ConcreteType::FieldElement => CheckedValue::Field(raw.pop().unwrap()),
ConcreteType::Uint(UBitwidth::B8) => CheckedValue::U8(
u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
),
Type::Uint(UBitwidth::B16) => CheckedValue::U16(
ConcreteType::Uint(UBitwidth::B16) => CheckedValue::U16(
u16::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
),
Type::Uint(UBitwidth::B32) => CheckedValue::U32(
ConcreteType::Uint(UBitwidth::B32) => CheckedValue::U32(
u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(),
),
Type::Boolean => {
ConcreteType::Boolean => {
let v = raw.pop().unwrap();
CheckedValue::Boolean(if v == 0.into() {
false
@ -211,12 +215,12 @@ impl<T: Field> Decode<T> for CheckedValue<T> {
unreachable!()
})
}
Type::Array(array_type) => CheckedValue::Array(
ConcreteType::Array(array_type) => CheckedValue::Array(
raw.chunks(array_type.ty.get_primitive_count())
.map(|c| CheckedValue::decode(c.to_vec(), *array_type.ty.clone()))
.collect(),
),
Type::Struct(members) => CheckedValue::Struct(
ConcreteType::Struct(members) => CheckedValue::Struct(
members
.into_iter()
.scan(0, |state, member| {
@ -247,9 +251,9 @@ impl<T: Field> TryFrom<serde_json::Value> for Values<T> {
match v {
serde_json::Value::Array(a) => a
.into_iter()
.map(|v| Value::try_from(v))
.map(Value::try_from)
.collect::<Result<_, _>>()
.map(|v| Values(v)),
.map(Values),
v => Err(format!("Expected an array of values, found `{}`", v)),
}
}
@ -259,20 +263,22 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
type Error = String;
fn try_from(v: serde_json::Value) -> Result<Value<T>, Self::Error> {
match v {
serde_json::Value::String(s) => T::try_from_dec_str(&s)
.map(|v| Value::Field(v))
.or_else(|_| match s.len() {
4 => u8::from_str_radix(&s[2..], 16)
.map(|v| Value::U8(v))
.map_err(|_| format!("Expected u8 value, found {}", s)),
6 => u16::from_str_radix(&s[2..], 16)
.map(|v| Value::U16(v))
.map_err(|_| format!("Expected u16 value, found {}", s)),
10 => u32::from_str_radix(&s[2..], 16)
.map(|v| Value::U32(v))
.map_err(|_| format!("Expected u32 value, found {}", s)),
_ => Err(format!("Cannot parse {} to any type", s)),
}),
serde_json::Value::String(s) => {
T::try_from_dec_str(&s)
.map(Value::Field)
.or_else(|_| match s.len() {
4 => u8::from_str_radix(&s[2..], 16)
.map(Value::U8)
.map_err(|_| format!("Expected u8 value, found {}", s)),
6 => u16::from_str_radix(&s[2..], 16)
.map(Value::U16)
.map_err(|_| format!("Expected u16 value, found {}", s)),
10 => u32::from_str_radix(&s[2..], 16)
.map(Value::U32)
.map_err(|_| format!("Expected u32 value, found {}", s)),
_ => Err(format!("Cannot parse {} to any type", s)),
})
}
serde_json::Value::Bool(b) => Ok(Value::Boolean(b)),
serde_json::Value::Number(n) => Err(format!(
"Value `{}` isn't allowed, did you mean `\"{}\"`?",
@ -280,14 +286,14 @@ impl<T: Field> TryFrom<serde_json::Value> for Value<T> {
)),
serde_json::Value::Array(a) => a
.into_iter()
.map(|v| Value::try_from(v))
.map(Value::try_from)
.collect::<Result<_, _>>()
.map(|v| Value::Array(v)),
.map(Value::Array),
serde_json::Value::Object(o) => o
.into_iter()
.map(|(k, v)| Value::try_from(v).map(|v| (k, v)))
.collect::<Result<Map<_, _>, _>>()
.map(|v| Value::Struct(v)),
.map(Value::Struct),
v => Err(format!("Value `{}` isn't allowed", v)),
}
}
@ -320,10 +326,13 @@ impl<T: Field> Into<serde_json::Value> for CheckedValues<T> {
fn parse<T: Field>(s: &str) -> Result<Values<T>, Error> {
let json_values: serde_json::Value =
serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?;
Values::try_from(json_values).map_err(|e| Error::Conversion(e))
Values::try_from(json_values).map_err(Error::Conversion)
}
pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues<T>, Error> {
pub fn parse_strict<T: Field>(
s: &str,
types: Vec<ConcreteType>,
) -> Result<CheckedValues<T>, Error> {
let parsed = parse(s)?;
if parsed.0.len() != types.len() {
return Err(Error::Type(format!(
@ -338,7 +347,7 @@ pub fn parse_strict<T: Field>(s: &str, types: Vec<Type>) -> Result<CheckedValues
.zip(types.into_iter())
.map(|(v, ty)| v.check(ty))
.collect::<Result<Vec<_>, _>>()
.map_err(|e| Error::Type(e))?;
.map_err(Error::Type)?;
Ok(CheckedValues(checked))
}
@ -403,14 +412,19 @@ mod tests {
mod strict {
use super::*;
use zokrates_core::typed_absy::types::{StructMember, StructType};
use zokrates_core::typed_absy::types::{
ConcreteStructMember, ConcreteStructType, ConcreteType,
};
#[test]
fn fields() {
let s = r#"["1", "2"]"#;
assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::FieldElement, Type::FieldElement])
.unwrap(),
parse_strict::<Bn128Field>(
s,
vec![ConcreteType::FieldElement, ConcreteType::FieldElement]
)
.unwrap(),
CheckedValues(vec![
CheckedValue::Field(1.into()),
CheckedValue::Field(2.into())
@ -422,7 +436,8 @@ mod tests {
fn bools() {
let s = "[true, false]";
assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::Boolean, Type::Boolean]).unwrap(),
parse_strict::<Bn128Field>(s, vec![ConcreteType::Boolean, ConcreteType::Boolean])
.unwrap(),
CheckedValues(vec![
CheckedValue::Boolean(true),
CheckedValue::Boolean(false)
@ -434,7 +449,11 @@ mod tests {
fn array() {
let s = "[[true, false]]";
assert_eq!(
parse_strict::<Bn128Field>(s, vec![Type::array(Type::Boolean, 2)]).unwrap(),
parse_strict::<Bn128Field>(
s,
vec![ConcreteType::array((ConcreteType::Boolean, 2usize))]
)
.unwrap(),
CheckedValues(vec![CheckedValue::Array(vec![
CheckedValue::Boolean(true),
CheckedValue::Boolean(false)
@ -448,10 +467,13 @@ mod tests {
assert_eq!(
parse_strict::<Bn128Field>(
s,
vec![Type::Struct(StructType::new(
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![StructMember::new("a".into(), Type::FieldElement)]
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))]
)
.unwrap(),
@ -466,10 +488,13 @@ mod tests {
assert_eq!(
parse_strict::<Bn128Field>(
s,
vec![Type::Struct(StructType::new(
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![StructMember::new("a".into(), Type::FieldElement)]
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))]
)
.unwrap_err(),
@ -480,10 +505,13 @@ mod tests {
assert_eq!(
parse_strict::<Bn128Field>(
s,
vec![Type::Struct(StructType::new(
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![StructMember::new("a".into(), Type::FieldElement)]
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))]
)
.unwrap_err(),
@ -494,10 +522,13 @@ mod tests {
assert_eq!(
parse_strict::<Bn128Field>(
s,
vec![Type::Struct(StructType::new(
vec![ConcreteType::Struct(ConcreteStructType::new(
"".into(),
"".into(),
vec![StructMember::new("a".into(), Type::FieldElement)]
vec![ConcreteStructMember::new(
"a".into(),
ConcreteType::FieldElement
)]
))]
)
.unwrap_err(),

View file

@ -12,6 +12,7 @@
- [Control flow](language/control_flow.md)
- [Imports](language/imports.md)
- [Comments](language/comments.md)
- [Generics](language/generics.md)
- [Macros](language/macros.md)
- [Toolbox](toolbox/index.md)

View file

@ -12,6 +12,12 @@ Arguments are passed by value.
{{#include ../../../zokrates_cli/examples/book/side_effects.zok}}
```
Generic paramaters, if any, must be compile-time constants. They are inferred by the compiler if that is possible, but can also be provided explicitly.
```zokrates
{{#include ../../../zokrates_cli/examples/book/generic_call.zok}}
```
### If-expressions
An if-expression allows you to branch your code depending on a boolean condition.
@ -28,7 +34,7 @@ For loops are available with the following syntax:
{{#include ../../../zokrates_cli/examples/book/for.zok}}
```
The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs.
The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. They can depend on generic parameters.
### Assertions

View file

@ -7,7 +7,14 @@ A function has to be declared at the top level before it is called.
```
A function's signature has to be explicitly provided.
Functions can return many values by providing them as a comma-separated list.
A function can be generic over any number of values of type `u32`.
```zokrates
{{#include ../../../zokrates_cli/examples/book/generic_function_declaration.zok}}
```
Functions can return multiple values by providing them as a comma-separated list.
```zokrates
{{#include ../../../zokrates_cli/examples/book/multi_return.zok}}

View file

@ -0,0 +1,7 @@
## Generics
ZoKrates supports code that is generic over constants of the `u32` type. No specific keyword is used: the compiler determines if the generic parameters are indeed constant at compile time. Here's an example of generic code in ZoKrates:
```zokrates
{{#include ../../../zokrates_cli/examples/book/generics.zok}}
```

View file

@ -8,7 +8,7 @@ ZoKrates currently exposes two primitive types and two complex types:
This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers.
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/reference/proving_schemes.html#alt_bn128) curve supported by Ethereum.
As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/toolbox/proving_schemes.html#alt_bn128) curve supported by Ethereum.
While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have:
@ -32,13 +32,23 @@ Similarly to booleans, unsigned integer inputs of the main function only accept
The division operation calculates the standard floor division for integers. The `%` operand can be used to obtain the remainder.
### Numeric inference
In the case of decimal literals like `42`, the compiler tries to find the appropriate type (`field`, `u8`, `u16` or `u32`) depending on the context. If it cannot converge to a single option, an error is returned. This means that there is no default type for decimal literals.
All operations between literals have the semantics of the infered type.
```zokrates
{{#include ../../../zokrates_cli/examples/book/numeric_inference.zok}}
```
## Complex Types
ZoKrates provides two complex types: arrays and structs.
### Arrays
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time.
ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](/language/generics.html)
Arrays can contain elements of any type and have arbitrary dimensions.
The following example code shows examples of how to use arrays:

View file

@ -11,5 +11,5 @@ fn export_stdlib() {
let out_dir = env::var("OUT_DIR").unwrap();
let mut options = CopyOptions::new();
options.overwrite = true;
copy_items(&vec!["tests/contract"], out_dir, &options).unwrap();
copy_items(&["tests/contract"], out_dir, &options).unwrap();
}

View file

@ -1,7 +1,7 @@
def main() -> field:
field[3] a = [1, 2, 3]
field c = 0
for field i in 0..3 do
def main() -> u32:
u32[3] a = [1, 2, 3]
u32 c = 0
for u32 i in 0..3 do
c = c + a[i]
endfor
return c

View file

@ -1,7 +1,7 @@
def main() -> (field[3]):
field[3] a = [1, 2, 3]
field[3] c = [4, 5, 6]
for field i in 0..3 do
def main() -> (u32[3]):
u32[3] a = [1, 2, 3]
u32[3] c = [4, 5, 6]
for u32 i in 0..3 do
c[i] = c[i] + a[i]
endfor
return c

View file

@ -3,7 +3,7 @@ def main(bool[3] a) -> (field[3]):
a[1] = true || a[2]
a[2] = a[0]
field[3] result = [0; 3]
for field i in 0..3 do
for u32 i in 0..3 do
result[i] = if a[i] then 33 else 0 fi
endfor
return result

View file

@ -1,9 +1,9 @@
def main(field[2][2][2] cube) -> field:
field res = 0
for field i in 0..2 do
for field j in 0..2 do
for field k in 0..2 do
for u32 i in 0..2 do
for u32 j in 0..2 do
for u32 k in 0..2 do
res = res + cube[i][j][k]
endfor
endfor

View file

@ -1,2 +1,2 @@
def main(field index, field[5] array) -> field:
def main(u32 index, field[5] array) -> field:
return array[index]

View file

@ -1,4 +1,4 @@
def main(field[10][10][10] a, field i, field j, field k) -> (field[3]):
def main(field[10][10][10] a, u32 i, u32 j, u32 k) -> (field[3]):
a[i][j][k] = 42
field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
return b[0]

View file

@ -0,0 +1,4 @@
def main(field a) -> field[4]:
u32 SIZE = 4
field[SIZE] res = [a; SIZE]
return res

View file

@ -1,6 +1,6 @@
def slice32from(field offset, field[2048] input) -> (field[32]):
def slice32from(u32 offset, field[2048] input) -> (field[32]):
field[32] result = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
for field i in 0..32 do
for u32 i in 0..32 do
result[i] = input[offset + i]
endfor
return result

View file

@ -1,4 +1,4 @@
def get(field[32] array, field index) -> field:
def get(field[32] array, u32 index) -> field:
return array[index]
def main() -> field:

View file

@ -5,4 +5,6 @@ def main() -> field:
field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4
field[2] d = a[1..3] // initialize an array copying a slice from `a`
bool[3] e = [true, true || false, true] // initialize a boolean array
u32 SIZE = 3
field[SIZE] f = [1, 2, 3] // initialize a field array with a size that's a compile-time constant
return a[0] + b[1] + c[2]

View file

@ -1,3 +1,3 @@
def main() -> ():
assert(1 == 2)
assert(1f == 2f)
return

View file

@ -1,7 +1,7 @@
def main() -> field:
field res = 0
for field i in 0..4 do
for field j in i..5 do
def main() -> u32:
u32 res = 0
for u32 i in 0..4 do
for u32 j in i..5 do
res = res + i
endfor
endfor

View file

@ -1,6 +1,6 @@
def main() -> field:
field a = 0
for field i in 0..5 do
def main() -> u32:
u32 a = 0
for u32 i in 0..5 do
a = a + i
endfor
// return i <- not allowed

View file

@ -1,5 +1,5 @@
def foo() -> field:
return 1
def foo(field a, field b) -> field:
return a + b
def main() -> field:
return foo()
return foo(1, 2)

View file

@ -0,0 +1,7 @@
def foo<N, P>() -> field[P]:
return [42; P]
def main() -> field[2]:
// `P` is inferred from the declaration of `res`, while `N` is provided explicitly
field[2] res = foo::<3, _>()
return res

View file

@ -0,0 +1,6 @@
def foo<N>() -> field[N]:
return [42; N]
def main() -> field[2]:
field[2] res = foo()
return res

View file

@ -0,0 +1,9 @@
def sum<N>(field[N] a) -> field:
field res = 0
for u32 i in 0..N do
res = res + a[i]
endfor
return res
def main(field[3] a) -> field:
return sum(a)

View file

@ -1,7 +1,7 @@
def main() -> field:
field a = 2
// field a = 3 <- not allowed
for field i in 0..5 do
for u32 i in 0..5 do
// field a = 7 <- not allowed
endfor
return a

View file

@ -0,0 +1,9 @@
def main():
// `255` is infered to `255f`, and the addition happens between field elements
assert(255 + 1f == 256)
// `255` is infered to `255u8`, and the addition happens between u8
// This causes an overflow
assert(255 + 1u8 == 0)
return

View file

@ -0,0 +1,5 @@
import "EMBED/u8_to_bits" as u8_to_bits
def main(u8 x):
bool[32] b = u8_to_bits(x) // note the incorrect array length on the left
return

View file

@ -0,0 +1,8 @@
def foo<N>(field[N] a) -> bool:
field[3] b = a
return true
def main(field[1] a):
assert(foo(a))
return

View file

@ -0,0 +1,3 @@
def main():
assert([1] == [1, 2])
return

View file

@ -0,0 +1,2 @@
def main<P>(field[P] a):
return

View file

@ -0,0 +1,5 @@
def foo<P>(field[P] a, field[P] b) -> field:
return 42
def main() -> field:
return foo([1, 2], [1])

View file

@ -0,0 +1,3 @@
def main():
assert([[1]] == [1, 2])
return

View file

@ -0,0 +1,2 @@
def main<P>():
return

View file

@ -1,4 +1,4 @@
def main() -> field:
for field i in 0..5 do
for u32 i in 0..5 do
endfor
return i

View file

@ -14,6 +14,6 @@ def main(field a) -> field:
assert(2 * b == a * 12 + 60)
field c = 7 * (b + a)
assert(isEqual(c, 7 * b + 7 * a))
field k = if [1, 2] == [3, 4] then 1 else 3 fi
field k = if [1f, 2] == [3f, 4] then 1 else 3 fi
assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}])
return b + c

View file

@ -1,6 +1,10 @@
def bound(field x) -> u32:
return 41 + 1
def main(field a) -> field:
field x = 7
for field i in 0..10 do
x = x + 1
for u32 i in 0..bound(x) do
// x = x + a
x = x + a
endfor

View file

@ -4,14 +4,14 @@ def lt(field a,field b) -> bool:
def cutoff() -> field:
return 31337
def getThing(field index) -> field:
def getThing(u32 index) -> field:
field[6] a = [13, 23, 43, 53, 73, 83]
return a[index]
def cubeThing(field thing) -> field:
return thing**3
def main(field index) -> bool:
def main(u32 index) -> bool:
field thing = getThing(index)
thing = cubeThing(thing)
return lt(cutoff(), thing)

View file

@ -4,9 +4,6 @@ import "ecc/babyjubjubParams" as context
from "ecc/babyjubjubParams" import BabyJubJubParams
import "hashes/utils/256bitsDirectionHelper" as multiplex
def multiplex(bool selector, u32[8] left, u32[8] right) -> (u32[8]):
return if selector then right else left fi
// Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes
// directionSelector=> true if current digest is on the rhs of the hash

View file

@ -1,9 +1,11 @@
import "utils/casts/u32_to_field" as to_field
// Binomial Coeffizient, n!/(k!*(n-k)!).
def fac(field x) -> field:
field f = 1
field counter = 0
for field i in 1..100 do
f = if counter == x then f else f * i fi
for u32 i in 1..100 do
f = if counter == x then f else f * to_field(i) fi
counter = if counter == x then counter else counter + 1 fi
endfor
return f

View file

@ -2,7 +2,7 @@ def main() -> field:
field a = 1 + 2 + 3
field b = if 1 < a then 3 else a + 3 fi
field c = if b + a == 2 then 1 else b fi
for field e in 0..2 do
for u32 e in 0..2 do
field g = 4
c = c + g
endfor

View file

@ -1,3 +1,3 @@
def main() -> field:
field a = 2
return 2**(a**2 + 2)
u32 a = 2
return 2**(a * 2 + 2)

View file

@ -32,26 +32,26 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva
bool res = true
// go through the whole grid and check that all elements are valid
for field i in 0..4 do
for field j in 0..4 do
for u32 i in 0..4 do
for u32 j in 0..4 do
res = res && validateInput(a[i][j])
endfor
endfor
// go through the 4 2x2 boxes and check that they do not contain duplicates
for field i in 0..1 do
for field j in 0..1 do
for u32 i in 0..1 do
for u32 j in 0..1 do
res = res && checkNoDuplicates(a[2*i][2*i], a[2*i][2*i + 1], a[2*i + 1][2*i], a[2*i + 1][2*i + 1])
endfor
endfor
// go through the 4 rows and check that they do not contain duplicates
for field i in 0..4 do
for u32 i in 0..4 do
res = res && checkNoDuplicates(a[i][0], a[i][1], a[i][2], a[i][3])
endfor
// go through the 4 columns and check that they do not contain duplicates
for field j in 0..4 do
for u32 j in 0..4 do
res = res && checkNoDuplicates(a[0][j], a[1][j], a[2][j], a[3][j])
endfor

View file

@ -10,6 +10,6 @@ def isWaldo(field a, field p, field q) -> bool:
return a == p * q
// define all
def main(field[3] a, private field index, private field p, private field q) -> bool:
def main(field[3] a, private u32 index, private field p, private field q) -> bool:
// prover provides the index of Waldo
return isWaldo(a[index], p, q)

View file

@ -32,10 +32,13 @@ fn cli() -> Result<(), String> {
compile::subcommand(),
check::subcommand(),
compute_witness::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
setup::subcommand(),
export_verifier::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
generate_proof::subcommand(),
print_proof::subcommand(),
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
verify::subcommand()])
.get_matches();
@ -140,7 +143,7 @@ mod tests {
let interpreter = ir::Interpreter::default();
let _ = interpreter
.execute(&artifacts.prog(), &vec![Bn128Field::from(0)])
.execute(&artifacts.prog(), &[Bn128Field::from(0)])
.unwrap();
}
}
@ -169,7 +172,7 @@ mod tests {
let interpreter = ir::Interpreter::default();
let res = interpreter.execute(&artifacts.prog(), &vec![Bn128Field::from(0)]);
let res = interpreter.execute(&artifacts.prog(), &[Bn128Field::from(0)]);
assert!(res.is_err());
}

View file

@ -19,6 +19,7 @@ lazy_static! {
.unwrap();
}
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
if cfg!(feature = "ark") {
if cfg!(feature = "bellman") {
@ -26,27 +27,21 @@ pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") {
} else {
&[LIBSNARK, ARK]
}
} else if cfg!(feature = "bellman") {
&[BELLMAN, LIBSNARK]
} else {
if cfg!(feature = "bellman") {
&[BELLMAN, LIBSNARK]
} else {
&[LIBSNARK]
}
&[LIBSNARK]
}
} else if cfg!(feature = "ark") {
if cfg!(feature = "bellman") {
&[BELLMAN, ARK]
} else {
&[ARK]
}
} else if cfg!(feature = "bellman") {
&[BELLMAN]
} else {
if cfg!(feature = "ark") {
if cfg!(feature = "bellman") {
&[BELLMAN, ARK]
} else {
&[ARK]
}
} else {
if cfg!(feature = "bellman") {
&[BELLMAN]
} else {
&[]
}
}
&[]
};
pub const BN128: &str = "bn128";

View file

@ -69,7 +69,7 @@ fn cli_check<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
format!(
"{}:{}",
file.strip_prefix(std::env::current_dir().unwrap())
.unwrap_or(file.as_path())
.unwrap_or_else(|_| file.as_path())
.display(),
e.value()
)

View file

@ -92,9 +92,9 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
let fmt_error = |e: &CompileError| {
let file = e.file().canonicalize().unwrap();
format!(
"{}:{}",
"{}: {}",
file.strip_prefix(std::env::current_dir().unwrap())
.unwrap_or(file.as_path())
.unwrap_or_else(|_| file.as_path())
.display(),
e.value()
)
@ -154,7 +154,7 @@ fn cli_compile<T: Field>(sub_matches: &ArgMatches) -> Result<(), String> {
.map_err(|why| format!("Couldn't create {}: {}", hr_output_path.display(), why))?;
let mut hrofb = BufWriter::new(hr_output_file);
write!(&mut hrofb, "{}\n", program_flattened)
writeln!(&mut hrofb, "{}", program_flattened)
.map_err(|_| "Unable to write data to file".to_string())?;
hrofb
.flush()

View file

@ -8,7 +8,7 @@ use zokrates_abi::Encode;
use zokrates_core::ir;
use zokrates_core::ir::ProgEnum;
use zokrates_core::typed_absy::abi::Abi;
use zokrates_core::typed_absy::{Signature, Type};
use zokrates_core::typed_absy::types::{ConcreteSignature, ConcreteType};
use zokrates_field::Field;
pub fn subcommand() -> App<'static, 'static> {
@ -106,9 +106,12 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
abi.signature()
}
false => Signature::new()
.inputs(vec![Type::FieldElement; ir_prog.main.arguments.len()])
.outputs(vec![Type::FieldElement; ir_prog.main.returns.len()]),
false => ConcreteSignature::new()
.inputs(vec![
ConcreteType::FieldElement;
ir_prog.main.arguments.len()
])
.outputs(vec![ConcreteType::FieldElement; ir_prog.main.returns.len()]),
};
use zokrates_abi::Inputs;
@ -123,8 +126,8 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
a.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
.collect::<Result<Vec<_>, _>>()
})
.unwrap_or(Ok(vec![]))
.map(|v| Inputs::Raw(v))
.unwrap_or_else(|| Ok(vec![]))
.map(Inputs::Raw)
}
// take stdin arguments
true => {
@ -137,7 +140,7 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
use zokrates_abi::parse_strict;
parse_strict(&input, signature.inputs)
.map(|parsed| Inputs::Abi(parsed))
.map(Inputs::Abi)
.map_err(|why| why.to_string())
}
Err(_) => Err(String::from("???")),
@ -148,10 +151,10 @@ fn cli_compute<T: Field>(ir_prog: ir::Prog<T>, sub_matches: &ArgMatches) -> Resu
Ok(_) => {
input.retain(|x| x != '\n');
input
.split(" ")
.split(' ')
.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string()))
.collect::<Result<Vec<_>, _>>()
.map(|v| Inputs::Raw(v))
.map(Inputs::Raw)
}
Err(_) => Err(String::from("???")),
},

View file

@ -174,7 +174,7 @@ fn cli_generate_proof<T: Field, S: Scheme<T>, B: Backend<T, S>>(
.write(proof.as_bytes())
.map_err(|why| format!("Couldn't write to {}: {}", proof_path.display(), why))?;
println!("Proof:\n{}", format!("{}", proof));
println!("Proof:\n{}", proof);
Ok(())
}

View file

@ -1,7 +1,6 @@
pub mod check;
pub mod compile;
pub mod compute_witness;
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub mod export_verifier;
#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))]
pub mod generate_proof;

View file

@ -1,9 +1,11 @@
import "utils/casts/u32_to_field" as to_field
// Binomial Coeffizient, n!/(k!*(n-k)!).
def fac(field x) -> field:
field f = 1
field counter = 0
for field i in 1..100 do
f = if counter == x then f else f * i fi
for u32 i in 1..100 do
f = if counter == x then f else f * to_field(i) fi
counter = if counter == x then counter else counter + 1 fi
endfor
return f

View file

@ -3,7 +3,7 @@ extern crate serde_json;
#[cfg(test)]
mod integration {
use assert_cli;
use serde_json::from_reader;
use std::fs;
use std::fs::File;
@ -147,11 +147,11 @@ mod integration {
.map_err(|why| why.to_string())
.unwrap();
let signature = abi.signature().clone();
let signature = abi.signature();
let inputs_abi: zokrates_abi::Inputs<zokrates_field::Bn128Field> =
parse_strict(&json_input_str, signature.inputs)
.map(|parsed| zokrates_abi::Inputs::Abi(parsed))
.map(zokrates_abi::Inputs::Abi)
.map_err(|why| why.to_string())
.unwrap();
let inputs_raw: Vec<_> = inputs_abi
@ -169,7 +169,7 @@ mod integration {
inline_witness_path.to_str().unwrap(),
];
if inputs_raw.len() > 0 {
if !inputs_raw.is_empty() {
compute_inline.push("-a");
for arg in &inputs_raw {
@ -202,7 +202,7 @@ mod integration {
assert_eq!(inline_witness, witness);
for line in expected_witness.as_str().split("\n") {
for line in expected_witness.as_str().split('\n') {
assert!(
witness.contains(line),
"Witness generation failed for {}\n\nLine \"{}\" not found in witness",

View file

@ -1,7 +1,6 @@
use crate::absy;
use crate::imports;
use num::ToPrimitive;
use num_bigint::BigUint;
use zokrates_pest_ast as pest;
@ -10,14 +9,14 @@ impl<'ast> From<pest::File<'ast>> for absy::Module<'ast> {
absy::Module::with_symbols(
prog.structs
.into_iter()
.map(|t| absy::SymbolDeclarationNode::from(t))
.map(absy::SymbolDeclarationNode::from)
.chain(
prog.functions
.into_iter()
.map(|f| absy::SymbolDeclarationNode::from(f)),
.map(absy::SymbolDeclarationNode::from),
),
)
.imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i)))
.imports(prog.imports.into_iter().map(absy::ImportNode::from))
}
}
@ -31,17 +30,16 @@ impl<'ast> From<pest::ImportDirective<'ast>> for absy::ImportNode<'ast> {
.alias(import.alias.map(|a| a.span.as_str()))
.span(import.span)
}
pest::ImportDirective::From(import) => imports::Import::new(
Some(import.symbol.span.as_str()),
std::path::Path::new(import.source.span.as_str()),
)
.alias(
import
.alias
.map(|a| a.span.as_str())
.or(Some(import.symbol.span.as_str())),
)
.span(import.span),
pest::ImportDirective::From(import) => {
let symbol_str = import.symbol.span.as_str();
imports::Import::new(
Some(import.symbol.span.as_str()),
std::path::Path::new(import.source.span.as_str()),
)
.alias(import.alias.map(|a| a.span.as_str()).or(Some(symbol_str)))
.span(import.span)
}
}
}
}
@ -58,7 +56,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
fields: definition
.fields
.into_iter()
.map(|f| absy::StructDefinitionFieldNode::from(f))
.map(absy::StructDefinitionFieldNode::from)
.collect(),
}
.span(span.clone());
@ -72,7 +70,7 @@ impl<'ast> From<pest::StructDefinition<'ast>> for absy::SymbolDeclarationNode<'a
}
impl<'ast> From<pest::StructField<'ast>> for absy::StructDefinitionFieldNode<'ast> {
fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode {
fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode<'ast> {
use crate::absy::NodeValue;
let span = field.span;
@ -92,6 +90,13 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
let span = function.span;
let signature = absy::UnresolvedSignature::new()
.generics(
function
.generics
.into_iter()
.map(absy::ConstantGenericNode::from)
.collect(),
)
.inputs(
function
.parameters
@ -105,7 +110,7 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
.returns
.clone()
.into_iter()
.map(|r| absy::UnresolvedTypeNode::from(r))
.map(absy::UnresolvedTypeNode::from)
.collect(),
);
@ -115,12 +120,12 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
arguments: function
.parameters
.into_iter()
.map(|a| absy::ParameterNode::from(a))
.map(absy::ParameterNode::from)
.collect(),
statements: function
.statements
.into_iter()
.flat_map(|s| statements_from_statement(s))
.flat_map(statements_from_statement)
.collect(),
signature,
}
@ -134,6 +139,16 @@ impl<'ast> From<pest::Function<'ast>> for absy::SymbolDeclarationNode<'ast> {
}
}
impl<'ast> From<pest::IdentifierExpression<'ast>> for absy::ConstantGenericNode<'ast> {
fn from(g: pest::IdentifierExpression<'ast>) -> absy::ConstantGenericNode<'ast> {
use absy::NodeValue;
let name = g.span.as_str();
name.span(g.span)
}
}
impl<'ast> From<pest::Parameter<'ast>> for absy::ParameterNode<'ast> {
fn from(param: pest::Parameter<'ast>) -> absy::ParameterNode<'ast> {
use crate::absy::NodeValue;
@ -247,7 +262,7 @@ impl<'ast> From<pest::ReturnStatement<'ast>> for absy::StatementNode<'ast> {
expressions: statement
.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.map(absy::ExpressionNode::from)
.collect(),
}
.span(statement.span.clone()),
@ -275,7 +290,7 @@ impl<'ast> From<pest::IterationStatement<'ast>> for absy::StatementNode<'ast> {
let statements: Vec<absy::StatementNode<'ast>> = statement
.statements
.into_iter()
.flat_map(|s| statements_from_statement(s))
.flat_map(statements_from_statement)
.collect();
let var = absy::Variable::new(index, ty).span(statement.index.span);
@ -289,7 +304,7 @@ impl<'ast> From<pest::Expression<'ast>> for absy::ExpressionNode<'ast> {
match expression {
pest::Expression::Binary(e) => absy::ExpressionNode::from(e),
pest::Expression::Ternary(e) => absy::ExpressionNode::from(e),
pest::Expression::Constant(e) => absy::ExpressionNode::from(e),
pest::Expression::Literal(e) => absy::ExpressionNode::from(e),
pest::Expression::Identifier(e) => absy::ExpressionNode::from(e),
pest::Expression::Postfix(e) => absy::ExpressionNode::from(e),
pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e),
@ -458,7 +473,7 @@ impl<'ast> From<pest::InlineArrayExpression<'ast>> for absy::ExpressionNode<'ast
array
.expressions
.into_iter()
.map(|e| absy::SpreadOrExpression::from(e))
.map(absy::SpreadOrExpression::from)
.collect(),
)
.span(array.span)
@ -489,13 +504,8 @@ impl<'ast> From<pest::ArrayInitializerExpression<'ast>> for absy::ExpressionNode
use crate::absy::NodeValue;
let value = absy::ExpressionNode::from(*initializer.value);
let count: absy::ExpressionNode<'ast> = absy::ExpressionNode::from(initializer.count);
let count = match count.value {
absy::Expression::FieldConstant(v) => v.to_usize().unwrap(),
_ => unreachable!(),
};
absy::Expression::InlineArray(vec![absy::SpreadOrExpression::Expression(value); count])
.span(initializer.span)
let count = absy::ExpressionNode::from(*initializer.count);
absy::Expression::ArrayInitializer(box value, box count).span(initializer.span)
}
}
@ -529,9 +539,25 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
pest::Access::Call(a) => match acc.value {
absy::Expression::Identifier(_) => absy::Expression::FunctionCall(
&id_str,
a.expressions
a.explicit_generics.map(|explicit_generics| {
explicit_generics
.values
.into_iter()
.map(|i| match i {
pest::ConstantGenericValue::Underscore(_) => None,
pest::ConstantGenericValue::Value(v) => {
Some(absy::ExpressionNode::from(v))
}
pest::ConstantGenericValue::Identifier(i) => {
Some(absy::Expression::Identifier(i.span.as_str()).span(i.span))
}
})
.collect()
}),
a.arguments
.expressions
.into_iter()
.map(|e| absy::ExpressionNode::from(e))
.map(absy::ExpressionNode::from)
.collect(),
),
e => unimplemented!("only identifiers are callable, found \"{}\"", e),
@ -548,29 +574,63 @@ impl<'ast> From<pest::PostfixExpression<'ast>> for absy::ExpressionNode<'ast> {
}
}
impl<'ast> From<pest::ConstantExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::ConstantExpression<'ast>) -> absy::ExpressionNode<'ast> {
impl<'ast> From<pest::DecimalLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::DecimalLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
match expression.suffix {
Some(suffix) => match suffix {
pest::DecimalSuffix::Field(_) => absy::Expression::FieldConstant(
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
),
pest::DecimalSuffix::U32(_) => absy::Expression::U32Constant(
u32::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
pest::DecimalSuffix::U16(_) => absy::Expression::U16Constant(
u16::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
pest::DecimalSuffix::U8(_) => absy::Expression::U8Constant(
u8::from_str_radix(&expression.value.span.as_str(), 10).unwrap(),
),
}
.span(expression.span),
None => absy::Expression::IntConstant(
BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(),
)
.span(expression.span),
}
}
}
impl<'ast> From<pest::HexLiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::HexLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
match expression.value {
pest::HexNumberExpression::U32(e) => {
absy::Expression::U32Constant(u32::from_str_radix(&e.span.as_str(), 16).unwrap())
}
pest::HexNumberExpression::U16(e) => {
absy::Expression::U16Constant(u16::from_str_radix(&e.span.as_str(), 16).unwrap())
}
pest::HexNumberExpression::U8(e) => {
absy::Expression::U8Constant(u8::from_str_radix(&e.span.as_str(), 16).unwrap())
}
}
.span(expression.span)
}
}
impl<'ast> From<pest::LiteralExpression<'ast>> for absy::ExpressionNode<'ast> {
fn from(expression: pest::LiteralExpression<'ast>) -> absy::ExpressionNode<'ast> {
use crate::absy::NodeValue;
match expression {
pest::ConstantExpression::BooleanLiteral(c) => {
pest::LiteralExpression::BooleanLiteral(c) => {
absy::Expression::BooleanConstant(c.value.parse().unwrap()).span(c.span)
}
pest::ConstantExpression::DecimalNumber(n) => absy::Expression::FieldConstant(
BigUint::parse_bytes(&n.value.as_bytes(), 10).unwrap(),
)
.span(n.span),
pest::ConstantExpression::U8(n) => absy::Expression::U8Constant(
u8::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
pest::ConstantExpression::U16(n) => absy::Expression::U16Constant(
u16::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
pest::ConstantExpression::U32(n) => absy::Expression::U32Constant(
u32::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(),
)
.span(n.span),
pest::LiteralExpression::DecimalLiteral(n) => absy::ExpressionNode::from(n),
pest::LiteralExpression::HexLiteral(n) => absy::ExpressionNode::from(n),
}
}
}
@ -611,8 +671,8 @@ impl<'ast> From<pest::Assignee<'ast>> for absy::AssigneeNode<'ast> {
}
}
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode {
impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode<'ast> {
fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode<'ast> {
use crate::absy::types::UnresolvedType;
use crate::absy::NodeValue;
@ -642,21 +702,7 @@ impl<'ast> From<pest::Type<'ast>> for absy::UnresolvedTypeNode {
t.dimensions
.into_iter()
.map(|s| match s {
pest::Expression::Constant(c) => match c {
pest::ConstantExpression::DecimalNumber(n) => {
str::parse::<usize>(&n.value).unwrap()
}
_ => unimplemented!(
"Array size should be a decimal number, found {}",
c.span().as_str()
),
},
e => unimplemented!(
"Array size should be constant, found {}",
e.span().as_str()
),
})
.map(absy::ExpressionNode::from)
.rev()
.fold(None, |acc, s| match acc {
None => Some(UnresolvedType::array(inner_type.clone(), s)),
@ -690,10 +736,9 @@ mod tests {
arguments: vec![],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(BigUint::from(
42u32,
))
.into()],
expressions: vec![
absy::Expression::IntConstant(42usize.into()).into()
],
}
.into(),
)
@ -771,10 +816,9 @@ mod tests {
],
statements: vec![absy::Statement::Return(
absy::ExpressionList {
expressions: vec![absy::Expression::FieldConstant(BigUint::from(
42u32,
))
.into()],
expressions: vec![
absy::Expression::IntConstant(42usize.into()).into()
],
}
.into(),
)
@ -800,7 +844,7 @@ mod tests {
use super::*;
/// Helper method to generate the ast for `def main(private {ty} a): return` which we use to check ty
fn wrap(ty: UnresolvedType) -> absy::Module<'static> {
fn wrap(ty: UnresolvedType<'static>) -> absy::Module<'static> {
absy::Module {
symbols: vec![absy::SymbolDeclaration {
id: "main",
@ -834,21 +878,31 @@ mod tests {
("bool", UnresolvedType::Boolean),
(
"field[2]",
UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 2),
),
(
"field[2][3]",
UnresolvedType::Array(
box UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 3)
.mock(),
2,
absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
(
"bool[2][3]",
UnresolvedType::Array(
box UnresolvedType::Array(box UnresolvedType::Boolean.mock(), 3).mock(),
2,
"field[2][3]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::FieldElement.mock(),
absy::Expression::IntConstant(3usize.into()).mock(),
)
.mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
(
"bool[2][3u32]",
absy::UnresolvedType::Array(
box absy::UnresolvedType::Array(
box absy::UnresolvedType::Boolean.mock(),
absy::Expression::U32Constant(3u32).mock(),
)
.mock(),
absy::Expression::IntConstant(2usize.into()).mock(),
),
),
];
@ -893,13 +947,13 @@ mod tests {
// we basically accept `()?[]*` : an optional call at first, then only array accesses
let vectors = vec![
("a", absy::Expression::Identifier("a").into()),
("a", absy::Expression::Identifier("a")),
(
"a[3]",
absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(3u32)).into(),
absy::Expression::IntConstant(3usize.into()).into(),
)
.into(),
),
@ -910,13 +964,13 @@ mod tests {
box absy::Expression::Select(
box absy::Expression::Identifier("a").into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(3u32)).into(),
absy::Expression::IntConstant(3usize.into()).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
),
@ -926,11 +980,12 @@ mod tests {
absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()],
None,
vec![absy::Expression::IntConstant(3usize.into()).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
),
@ -941,17 +996,18 @@ mod tests {
box absy::Expression::Select(
box absy::Expression::FunctionCall(
"a",
vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()],
None,
vec![absy::Expression::IntConstant(3usize.into()).into()],
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(4u32)).into(),
absy::Expression::IntConstant(4usize.into()).into(),
)
.into(),
)
.into(),
box absy::RangeOrExpression::Expression(
absy::Expression::FieldConstant(BigUint::from(5u32)).into(),
absy::Expression::IntConstant(5usize.into()).into(),
)
.into(),
),
@ -993,7 +1049,7 @@ mod tests {
// For different definitions, we generate declarations
// Case 1: `id = expr` where `expr` is not a function call
// This is a simple assignment, doesn't implicitely declare a variable
// A `Definition` is generatedm and no `Declaration`s
// A `Definition` is generated and no `Declaration`s
let definition = pest::DefinitionStatement {
lhs: vec![pest::OptionallyTypedAssignee {
@ -1008,9 +1064,12 @@ mod tests {
},
span: span.clone(),
}],
expression: pest::Expression::Constant(pest::ConstantExpression::DecimalNumber(
pest::DecimalNumberExpression {
value: String::from("42"),
expression: pest::Expression::Literal(pest::LiteralExpression::DecimalLiteral(
pest::DecimalLiteralExpression {
value: pest::DecimalNumber {
span: Span::new(&"1", 0, 1).unwrap(),
},
suffix: None,
span: span.clone(),
},
)),
@ -1049,7 +1108,11 @@ mod tests {
span: span.clone(),
},
accesses: vec![pest::Access::Call(pest::CallAccess {
expressions: vec![],
explicit_generics: None,
arguments: pest::Arguments {
expressions: vec![],
span: span.clone(),
},
span: span.clone(),
})],
span: span.clone(),
@ -1106,7 +1169,11 @@ mod tests {
span: span.clone(),
},
accesses: vec![pest::Access::Call(pest::CallAccess {
expressions: vec![],
explicit_generics: None,
arguments: pest::Arguments {
expressions: vec![],
span: span.clone(),
},
span: span.clone(),
})],
span: span.clone(),

View file

@ -105,7 +105,7 @@ impl<'ast> Module<'ast> {
}
}
pub type UnresolvedTypeNode = Node<UnresolvedType>;
pub type UnresolvedTypeNode<'ast> = Node<UnresolvedType<'ast>>;
/// A struct type definition
#[derive(Debug, Clone, PartialEq)]
@ -133,7 +133,7 @@ pub type StructDefinitionNode<'ast> = Node<StructDefinition<'ast>>;
#[derive(Debug, Clone, PartialEq)]
pub struct StructDefinitionField<'ast> {
pub id: Identifier<'ast>,
pub ty: UnresolvedTypeNode,
pub ty: UnresolvedTypeNode<'ast>,
}
impl<'ast> fmt::Display for StructDefinitionField<'ast> {
@ -216,6 +216,8 @@ impl<'ast> fmt::Debug for Module<'ast> {
}
}
pub type ConstantGenericNode<'ast> = Node<Identifier<'ast>>;
/// A function defined locally
#[derive(Clone, PartialEq)]
pub struct Function<'ast> {
@ -224,13 +226,26 @@ pub struct Function<'ast> {
/// Vector of statements that are executed when running the function
pub statements: Vec<StatementNode<'ast>>,
/// function signature
pub signature: UnresolvedSignature,
pub signature: UnresolvedSignature<'ast>,
}
pub type FunctionNode<'ast> = Node<Function<'ast>>;
impl<'ast> fmt::Display for Function<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.signature.generics.is_empty() {
write!(
f,
"<{}>",
self.signature
.generics
.iter()
.map(|g| g.to_string())
.collect::<Vec<_>>()
.join(", ")
)?;
}
write!(
f,
"({}):\n{}",
@ -294,6 +309,7 @@ impl<'ast> fmt::Display for Assignee<'ast> {
}
/// A statement in a `Function`
#[allow(clippy::large_enum_variant)]
#[derive(Clone, PartialEq)]
pub enum Statement<'ast> {
Return(ExpressionListNode<'ast>),
@ -319,9 +335,9 @@ impl<'ast> fmt::Display for Statement<'ast> {
Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs),
Statement::Assertion(ref e) => write!(f, "assert({})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {} in {}..{} do\n", var, start, stop)?;
writeln!(f, "for {} in {}..{} do", var, start, stop)?;
for l in list {
write!(f, "\t\t{}\n", l)?;
writeln!(f, "\t\t{}", l)?;
}
write!(f, "\tendfor")
}
@ -348,9 +364,9 @@ impl<'ast> fmt::Debug for Statement<'ast> {
}
Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e),
Statement::For(ref var, ref start, ref stop, ref list) => {
write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?;
writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?;
for l in list {
write!(f, "\t\t{:?}\n", l)?;
writeln!(f, "\t\t{:?}", l)?;
}
write!(f, "\tendfor")
}
@ -454,11 +470,11 @@ impl<'ast> fmt::Display for Range<'ast> {
self.from
.as_ref()
.map(|e| e.to_string())
.unwrap_or("".to_string()),
.unwrap_or_else(|| "".to_string()),
self.to
.as_ref()
.map(|e| e.to_string())
.unwrap_or("".to_string())
.unwrap_or_else(|| "".to_string())
)
}
}
@ -472,6 +488,7 @@ impl<'ast> fmt::Debug for Range<'ast> {
/// An expression
#[derive(Clone, PartialEq)]
pub enum Expression<'ast> {
IntConstant(BigUint),
FieldConstant(BigUint),
BooleanConstant(bool),
U8Constant(u8),
@ -491,7 +508,11 @@ pub enum Expression<'ast> {
Box<ExpressionNode<'ast>>,
Box<ExpressionNode<'ast>>,
),
FunctionCall(FunctionIdentifier<'ast>, Vec<ExpressionNode<'ast>>),
FunctionCall(
FunctionIdentifier<'ast>,
Option<Vec<Option<ExpressionNode<'ast>>>>,
Vec<ExpressionNode<'ast>>,
),
Lt(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Le(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Eq(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
@ -500,6 +521,7 @@ pub enum Expression<'ast> {
And(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
Not(Box<ExpressionNode<'ast>>),
InlineArray(Vec<SpreadOrExpression<'ast>>),
ArrayInitializer(Box<ExpressionNode<'ast>>, Box<ExpressionNode<'ast>>),
InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast>)>),
Select(Box<ExpressionNode<'ast>>, Box<RangeOrExpression<'ast>>),
Member(Box<ExpressionNode<'ast>>, Box<Identifier<'ast>>),
@ -520,6 +542,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
Expression::U8Constant(ref i) => write!(f, "{}", i),
Expression::U16Constant(ref i) => write!(f, "{}", i),
Expression::U32Constant(ref i) => write!(f, "{}", i),
Expression::IntConstant(ref i) => write!(f, "{}", i),
Expression::Identifier(ref var) => write!(f, "{}", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs),
@ -535,8 +558,21 @@ impl<'ast> fmt::Display for Expression<'ast> {
"if {} then {} else {} fi",
condition, consequent, alternative
),
Expression::FunctionCall(ref i, ref p) => {
write!(f, "{}(", i,)?;
Expression::FunctionCall(ref i, ref g, ref p) => {
if let Some(g) = g {
write!(
f,
"::<{}>",
g.iter()
.map(|g| g
.as_ref()
.map(|g| g.to_string())
.unwrap_or_else(|| "_".into()))
.collect::<Vec<_>>()
.join(", "),
)?;
}
write!(f, "{}(", i)?;
for (i, param) in p.iter().enumerate() {
write!(f, "{}", param)?;
if i < p.len() - 1 {
@ -562,6 +598,7 @@ impl<'ast> fmt::Display for Expression<'ast> {
}
write!(f, "]")
}
Expression::ArrayInitializer(ref e, ref count) => write!(f, "[{}; {}]", e, count),
Expression::InlineStruct(ref id, ref members) => {
write!(f, "{} {{", id)?;
for (i, (member_id, e)) in members.iter().enumerate() {
@ -587,10 +624,11 @@ impl<'ast> fmt::Display for Expression<'ast> {
impl<'ast> fmt::Debug for Expression<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Expression::U8Constant(ref i) => write!(f, "{:x}", i),
Expression::U16Constant(ref i) => write!(f, "{:x}", i),
Expression::U32Constant(ref i) => write!(f, "{:x}", i),
Expression::FieldConstant(ref i) => write!(f, "Num({:?})", i),
Expression::U8Constant(ref i) => write!(f, "U8({:x})", i),
Expression::U16Constant(ref i) => write!(f, "U16({:x})", i),
Expression::U32Constant(ref i) => write!(f, "U32({:x})", i),
Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i),
Expression::IntConstant(ref i) => write!(f, "Int({:?})", i),
Expression::Identifier(ref var) => write!(f, "Ide({})", var),
Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
@ -606,8 +644,8 @@ impl<'ast> fmt::Debug for Expression<'ast> {
"IfElse({:?}, {:?}, {:?})",
condition, consequent, alternative
),
Expression::FunctionCall(ref i, ref p) => {
write!(f, "FunctionCall({:?}, (", i)?;
Expression::FunctionCall(ref g, ref i, ref p) => {
write!(f, "FunctionCall({:?}, {:?}, (", g, i)?;
f.debug_list().entries(p.iter()).finish()?;
write!(f, ")")
}
@ -623,6 +661,9 @@ impl<'ast> fmt::Debug for Expression<'ast> {
f.debug_list().entries(exprs.iter()).finish()?;
write!(f, "]")
}
Expression::ArrayInitializer(ref e, ref count) => {
write!(f, "ArrayInitializer({:?}, {:?})", e, count)
}
Expression::InlineStruct(ref id, ref members) => {
write!(f, "InlineStruct({:?}, [", id)?;
f.debug_list().entries(members.iter()).finish()?;
@ -645,21 +686,13 @@ impl<'ast> fmt::Debug for Expression<'ast> {
}
/// A list of expressions, used in return statements
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Default)]
pub struct ExpressionList<'ast> {
pub expressions: Vec<ExpressionNode<'ast>>,
}
pub type ExpressionListNode<'ast> = Node<ExpressionList<'ast>>;
impl<'ast> ExpressionList<'ast> {
pub fn new() -> ExpressionList<'ast> {
ExpressionList {
expressions: vec![],
}
}
}
impl<'ast> fmt::Display for ExpressionList<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for (i, param) in self.expressions.iter().enumerate() {

View file

@ -81,7 +81,7 @@ impl<'ast> NodeValue for ExpressionList<'ast> {}
impl<'ast> NodeValue for Assignee<'ast> {}
impl<'ast> NodeValue for Statement<'ast> {}
impl<'ast> NodeValue for SymbolDeclaration<'ast> {}
impl NodeValue for UnresolvedType {}
impl<'ast> NodeValue for UnresolvedType<'ast> {}
impl<'ast> NodeValue for StructDefinition<'ast> {}
impl<'ast> NodeValue for StructDefinitionField<'ast> {}
impl<'ast> NodeValue for Function<'ast> {}
@ -92,6 +92,7 @@ impl<'ast> NodeValue for Parameter<'ast> {}
impl<'ast> NodeValue for Import<'ast> {}
impl<'ast> NodeValue for Spread<'ast> {}
impl<'ast> NodeValue for Range<'ast> {}
impl<'ast> NodeValue for Identifier<'ast> {}
impl<T: PartialEq> PartialEq for Node<T> {
fn eq(&self, other: &Node<T>) -> bool {

View file

@ -1,3 +1,4 @@
use crate::absy::ExpressionNode;
use crate::absy::UnresolvedTypeNode;
use std::fmt;
@ -8,15 +9,15 @@ pub type MemberId = String;
pub type UserTypeId = String;
#[derive(Clone, PartialEq, Debug)]
pub enum UnresolvedType {
pub enum UnresolvedType<'ast> {
FieldElement,
Boolean,
Uint(usize),
Array(Box<UnresolvedTypeNode>, usize),
Array(Box<UnresolvedTypeNode<'ast>>, ExpressionNode<'ast>),
User(UserTypeId),
}
impl fmt::Display for UnresolvedType {
impl<'ast> fmt::Display for UnresolvedType<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
UnresolvedType::FieldElement => write!(f, "field"),
@ -28,8 +29,8 @@ impl fmt::Display for UnresolvedType {
}
}
impl UnresolvedType {
pub fn array(ty: UnresolvedTypeNode, size: usize) -> Self {
impl<'ast> UnresolvedType<'ast> {
pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self {
UnresolvedType::Array(box ty, size)
}
}
@ -39,17 +40,19 @@ pub type FunctionIdentifier<'ast> = &'ast str;
pub use self::signature::UnresolvedSignature;
mod signature {
use crate::absy::ConstantGenericNode;
use std::fmt;
use crate::absy::UnresolvedTypeNode;
#[derive(Clone, PartialEq)]
pub struct UnresolvedSignature {
pub inputs: Vec<UnresolvedTypeNode>,
pub outputs: Vec<UnresolvedTypeNode>,
#[derive(Clone, PartialEq, Default)]
pub struct UnresolvedSignature<'ast> {
pub generics: Vec<ConstantGenericNode<'ast>>,
pub inputs: Vec<UnresolvedTypeNode<'ast>>,
pub outputs: Vec<UnresolvedTypeNode<'ast>>,
}
impl fmt::Debug for UnresolvedSignature {
impl<'ast> fmt::Debug for UnresolvedSignature<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
@ -59,7 +62,7 @@ mod signature {
}
}
impl fmt::Display for UnresolvedSignature {
impl<'ast> fmt::Display for UnresolvedSignature<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(")?;
for (i, t) in self.inputs.iter().enumerate() {
@ -79,20 +82,22 @@ mod signature {
}
}
impl UnresolvedSignature {
pub fn new() -> UnresolvedSignature {
UnresolvedSignature {
inputs: vec![],
outputs: vec![],
}
impl<'ast> UnresolvedSignature<'ast> {
pub fn new() -> UnresolvedSignature<'ast> {
UnresolvedSignature::default()
}
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode>) -> Self {
pub fn generics(mut self, generics: Vec<ConstantGenericNode<'ast>>) -> Self {
self.generics = generics;
self
}
pub fn inputs(mut self, inputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
self.inputs = inputs;
self
}
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode>) -> Self {
pub fn outputs(mut self, outputs: Vec<UnresolvedTypeNode<'ast>>) -> Self {
self.outputs = outputs;
self
}

View file

@ -7,21 +7,21 @@ use crate::absy::Identifier;
#[derive(Clone, PartialEq)]
pub struct Variable<'ast> {
pub id: Identifier<'ast>,
pub _type: UnresolvedTypeNode,
pub _type: UnresolvedTypeNode<'ast>,
}
pub type VariableNode<'ast> = Node<Variable<'ast>>;
impl<'ast> Variable<'ast> {
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode) -> Variable<'ast> {
pub fn new<S: Into<&'ast str>>(id: S, t: UnresolvedTypeNode<'ast>) -> Variable<'ast> {
Variable {
id: id.into(),
_type: t,
}
}
pub fn get_type(&self) -> UnresolvedType {
self._type.value.clone()
pub fn get_type(&self) -> &UnresolvedType<'ast> {
&self._type.value
}
}

View file

@ -9,6 +9,7 @@ use crate::imports::{self, Importer};
use crate::ir;
use crate::macros;
use crate::semantics::{self, Checker};
use crate::static_analysis;
use crate::static_analysis::Analyse;
use crate::typed_absy::abi::Abi;
use crate::zir::ZirProgram;
@ -55,6 +56,7 @@ pub enum CompileErrorInner {
MacroError(macros::Error),
SemanticError(semantics::ErrorInner),
ReadError(io::Error),
AnalysisError(static_analysis::Error),
}
impl CompileErrorInner {
@ -129,6 +131,12 @@ impl From<semantics::Error> for CompileError {
}
}
impl From<static_analysis::Error> for CompileErrorInner {
fn from(error: static_analysis::Error) -> Self {
CompileErrorInner::AnalysisError(error)
}
}
impl fmt::Display for CompileErrorInner {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -137,6 +145,7 @@ impl fmt::Display for CompileErrorInner {
CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e),
CompileErrorInner::ReadError(ref e) => write!(f, "{}", e),
CompileErrorInner::ImportError(ref e) => write!(f, "{}", e),
CompileErrorInner::AnalysisError(ref e) => write!(f, "{}", e),
}
}
}
@ -179,7 +188,7 @@ pub fn compile<T: Field, E: Into<imports::Error>>(
})
}
pub fn check<'ast, T: Field, E: Into<imports::Error>>(
pub fn check<T: Field, E: Into<imports::Error>>(
source: String,
location: FilePath,
resolver: Option<&dyn Resolver<E>>,
@ -196,19 +205,18 @@ fn check_with_arena<'ast, T: Field, E: Into<imports::Error>>(
arena: &'ast Arena<String>,
) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> {
let source = arena.alloc(source);
let compiled = compile_program::<T, E>(source, location.clone(), resolver, &arena)?;
let compiled = compile_program::<T, E>(source, location, resolver, &arena)?;
// check semantics
let typed_ast = Checker::check(compiled).map_err(|errors| {
CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect())
})?;
let typed_ast = Checker::check(compiled)
.map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?;
let abi = typed_ast.abi();
let main_module = typed_ast.main.clone();
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
Ok((typed_ast, abi))
typed_ast
.analyse()
.map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)]))
}
pub fn compile_program<'ast, T: Field, E: Into<imports::Error>>(
@ -244,7 +252,7 @@ pub fn compile_module<'ast, T: Field, E: Into<imports::Error>>(
let module_without_imports: Module = Module::from(ast);
Importer::new().apply_imports::<T, E>(
Importer::apply_imports::<T, E>(
module_without_imports,
location.clone(),
resolver,
@ -380,17 +388,17 @@ struct Bar { field a }
inputs: vec![AbiInput {
name: "f".into(),
public: true,
ty: Type::Struct(StructType::new(
ty: ConcreteType::Struct(ConcreteStructType::new(
"foo".into(),
"Foo".into(),
vec![StructMember {
vec![ConcreteStructMember {
id: "b".into(),
ty: box Type::Struct(StructType::new(
ty: box ConcreteType::Struct(ConcreteStructType::new(
"bar".into(),
"Bar".into(),
vec![StructMember {
vec![ConcreteStructMember {
id: "a".into(),
ty: box Type::FieldElement
ty: box ConcreteType::FieldElement
}]
))
}]

View file

@ -3,7 +3,9 @@ use crate::flat_absy::{
FlatVariable,
};
use crate::solvers::Solver;
use crate::typed_absy::types::{FunctionKey, Signature, Type};
use crate::typed_absy::types::{
ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType,
};
use std::collections::HashMap;
use zokrates_field::{Bn128Field, Field};
@ -18,11 +20,12 @@ cfg_if::cfg_if! {
/// A low level function that contains non-deterministic introduction of variables. It is carried out as is until
/// the flattening step when it can be inlined.
#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
pub enum FlatEmbed {
U32ToField,
#[cfg(feature = "bellman")]
Sha256Round,
Unpack(usize),
Unpack,
U8ToBits,
U16ToBits,
U32ToBits,
@ -32,48 +35,88 @@ pub enum FlatEmbed {
}
impl FlatEmbed {
pub fn signature(&self) -> Signature {
pub fn signature(&self) -> DeclarationSignature<'static> {
match self {
FlatEmbed::U32ToField => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::FieldElement]),
FlatEmbed::Unpack => DeclarationSignature::new()
.generics(vec![Some(Constant::Generic("N"))])
.inputs(vec![DeclarationType::FieldElement])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
"N",
))]),
FlatEmbed::U8ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(8)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
8usize,
))]),
FlatEmbed::U16ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(16)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
16usize,
))]),
FlatEmbed::U32ToBits => DeclarationSignature::new()
.inputs(vec![DeclarationType::uint(32)])
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
32usize,
))]),
FlatEmbed::U8FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(8)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
8usize,
))]),
FlatEmbed::U16FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(16)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
16usize,
))]),
FlatEmbed::U32FromBits => DeclarationSignature::new()
.outputs(vec![DeclarationType::uint(32)])
.inputs(vec![DeclarationType::array((
DeclarationType::Boolean,
32usize,
))]),
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => Signature::new()
FlatEmbed::Sha256Round => DeclarationSignature::new()
.inputs(vec![
Type::array(Type::Boolean, 512),
Type::array(Type::Boolean, 256),
DeclarationType::array((DeclarationType::Boolean, 512usize)),
DeclarationType::array((DeclarationType::Boolean, 256usize)),
])
.outputs(vec![Type::array(Type::Boolean, 256)]),
FlatEmbed::Unpack(bitwidth) => Signature::new()
.inputs(vec![Type::FieldElement])
.outputs(vec![Type::array(Type::Boolean, *bitwidth)]),
FlatEmbed::U8ToBits => Signature::new()
.inputs(vec![Type::uint(8)])
.outputs(vec![Type::array(Type::Boolean, 8)]),
FlatEmbed::U16ToBits => Signature::new()
.inputs(vec![Type::uint(16)])
.outputs(vec![Type::array(Type::Boolean, 16)]),
FlatEmbed::U32ToBits => Signature::new()
.inputs(vec![Type::uint(32)])
.outputs(vec![Type::array(Type::Boolean, 32)]),
FlatEmbed::U8FromBits => Signature::new()
.outputs(vec![Type::uint(8)])
.inputs(vec![Type::array(Type::Boolean, 8)]),
FlatEmbed::U16FromBits => Signature::new()
.outputs(vec![Type::uint(16)])
.inputs(vec![Type::array(Type::Boolean, 16)]),
FlatEmbed::U32FromBits => Signature::new()
.outputs(vec![Type::uint(32)])
.inputs(vec![Type::array(Type::Boolean, 32)]),
.outputs(vec![DeclarationType::array((
DeclarationType::Boolean,
256usize,
))]),
}
}
pub fn key<T: Field>(&self) -> FunctionKey<'static> {
FunctionKey::with_id(self.id()).signature(self.signature())
pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec<u32> {
let gen = self
.signature()
.generics
.into_iter()
.map(|c| match c.unwrap() {
Constant::Generic(g) => g,
_ => unreachable!(),
});
assert_eq!(gen.len(), assignment.0.len());
gen.map(|g| *assignment.0.get(&g).clone().unwrap() as u32)
.collect()
}
pub fn id(&self) -> &'static str {
match self {
FlatEmbed::U32ToField => "_U32_TO_FIELD",
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => "_SHA256_ROUND",
FlatEmbed::Unpack(_) => "_UNPACK",
FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS",
FlatEmbed::U16ToBits => "_U16_TO_BITS",
FlatEmbed::U32ToBits => "_U32_TO_BITS",
@ -84,11 +127,11 @@ impl FlatEmbed {
}
/// Actually get the `FlatFunction` that this `FlatEmbed` represents
pub fn synthetize<T: Field>(&self) -> FlatFunction<T> {
pub fn synthetize<T: Field>(&self, generics: &[u32]) -> FlatFunction<T> {
match self {
#[cfg(feature = "bellman")]
FlatEmbed::Sha256Round => sha256_round(),
FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth),
FlatEmbed::Unpack => unpack_to_bitwidth(generics[0] as usize),
_ => unreachable!(),
}
}
@ -101,7 +144,7 @@ fn flat_expression_from_vec<T: Field, E: Engine>(v: &[(usize, E::Fr)]) -> FlatEx
match v.len() {
0 => FlatExpression::Number(T::zero()),
1 => {
let (key, val) = v[0].clone();
let (key, val) = v[0];
let mut res: Vec<u8> = vec![];
val.into_repr().write_le(&mut res).unwrap();
FlatExpression::Mult(
@ -152,7 +195,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
let output_indices = output_indices.into_iter();
let variable_count = r1cs.aux_count + 1; // auxiliary and ONE
// indices of the sha256round constraint system variables
let cs_indices = (0..variable_count).into_iter();
let cs_indices = 0..variable_count;
// indices of the arguments to the function
// apply an offset of `variable_count` to get the indice of our dummy `input` argument
let input_argument_indices = input_indices
@ -180,7 +223,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
);
let input_binding_statements =
// bind input and current_hash to inputs
input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| {
FlatStatement::Condition(
FlatVariable::new(cs_index).into(),
FlatVariable::new(argument_index).into(),
@ -194,7 +237,7 @@ pub fn sha256_round<T: Field>() -> FlatFunction<T> {
.collect();
// insert a directive to set the witness based on the bellman gadget and inputs
let directive_statement = FlatStatement::Directive(FlatDirective {
outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(),
outputs: cs_indices.map(FlatVariable::new).collect(),
inputs: input_argument_indices
.chain(current_hash_argument_indices)
.map(|i| FlatVariable::new(i).into())
@ -224,7 +267,7 @@ fn use_variable(
) -> FlatVariable {
let var = FlatVariable::new(*index);
layout.insert(name, var);
*index = *index + 1;
*index += 1;
var
}
@ -255,7 +298,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let directive_inputs = vec![FlatExpression::Identifier(use_variable(
&mut layout,
format!("i0"),
"i0".into(),
&mut counter,
))];
@ -268,7 +311,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
let outputs = directive_outputs
.iter()
.enumerate()
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
.map(|(_, o)| FlatExpression::Identifier(*o))
.collect::<Vec<_>>();
// o253, o252, ... o{253 - (bit_width - 1)} are bits
@ -308,7 +351,7 @@ pub fn unpack_to_bitwidth<T: Field>(bit_width: usize) -> FlatFunction<T> {
FlatStatement::Directive(FlatDirective {
inputs: directive_inputs,
outputs: directive_outputs,
solver: solver,
solver,
}),
);
@ -436,7 +479,7 @@ mod tests {
private: vec![true; 768],
};
let input = (0..512)
let input: Vec<_> = (0..512)
.map(|_| 0)
.chain((0..256).map(|_| 1))
.map(|i| Bn128Field::from(i))

View file

@ -41,7 +41,7 @@ impl FlatParameter {
substitution: &HashMap<FlatVariable, FlatVariable>,
) -> FlatParameter {
FlatParameter {
id: substitution.get(&self.id).unwrap().clone(),
id: *substitution.get(&self.id).unwrap(),
private: self.private,
}
}

View file

@ -45,7 +45,7 @@ impl FlatVariable {
Ok(FlatVariable::public(v))
}
None => {
let mut private = s.split("_");
let mut private = s.split('_');
match private.nth(1) {
Some(v) => {
let v = v.parse().map_err(|_| s)?;

View file

@ -155,23 +155,13 @@ impl<T: Field> FlatStatement<T> {
}
}
#[derive(Clone, Hash, Debug)]
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
pub struct FlatDirective<T: Field> {
pub inputs: Vec<FlatExpression<T>>,
pub outputs: Vec<FlatVariable>,
pub solver: Solver,
}
impl<T: Field> PartialEq for FlatDirective<T> {
fn eq(&self, other: &Self) -> bool {
self.inputs.eq(&other.inputs)
&& self.outputs.eq(&other.outputs)
&& self.solver.eq(&other.solver)
}
}
impl<T: Field> Eq for FlatDirective<T> {}
impl<T: Field> FlatDirective<T> {
pub fn new<E: Into<FlatExpression<T>>>(
outputs: Vec<FlatVariable>,
@ -249,12 +239,18 @@ impl<T: Field> FlatExpression<T> {
FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => {
x.is_linear() && y.is_linear()
}
FlatExpression::Mult(ref x, ref y) => match (x.clone(), y.clone()) {
FlatExpression::Mult(ref x, ref y) => matches!(
(x.clone(), y.clone()),
(box FlatExpression::Number(_), box FlatExpression::Number(_))
| (box FlatExpression::Number(_), box FlatExpression::Identifier(_))
| (box FlatExpression::Identifier(_), box FlatExpression::Number(_)) => true,
_ => false,
},
| (
box FlatExpression::Number(_),
box FlatExpression::Identifier(_)
)
| (
box FlatExpression::Identifier(_),
box FlatExpression::Number(_)
)
),
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -42,7 +42,7 @@ impl fmt::Display for Error {
let location = self
.pos
.map(|p| format!("{}", p.0))
.unwrap_or("?".to_string());
.unwrap_or_else(|| "?".to_string());
write!(f, "{}\n\t{}", location, self.message)
}
}
@ -125,15 +125,10 @@ impl<'ast> fmt::Debug for Import<'ast> {
}
}
pub struct Importer {}
pub struct Importer;
impl Importer {
pub fn new() -> Importer {
Importer {}
}
pub fn apply_imports<'ast, T: Field, E: Into<Error>>(
&self,
destination: Module<'ast>,
location: PathBuf,
resolver: Option<&dyn Resolver<E>>,
@ -179,7 +174,7 @@ impl Importer {
symbols.push(
SymbolDeclaration {
id: &alias,
symbol: Symbol::Flat(FlatEmbed::Unpack(T::get_required_bits())),
symbol: Symbol::Flat(FlatEmbed::Unpack),
}
.start_end(pos.0, pos.1),
);
@ -267,13 +262,15 @@ impl Importer {
let alias = import.alias.unwrap_or(
std::path::Path::new(import.source)
.file_stem()
.ok_or(CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
import.source.display()
)))
.in_file(&location),
))?
.ok_or_else(|| {
CompileErrors::from(
CompileErrorInner::ImportError(Error::new(format!(
"Could not determine alias for import {}",
import.source.display()
)))
.in_file(&location),
)
})?
.to_str()
.unwrap(),
);
@ -335,7 +332,6 @@ impl Importer {
Ok(Module {
imports: vec![],
symbols,
..destination
})
}
}

View file

@ -2,49 +2,36 @@ use crate::flat_absy::FlatVariable;
use serde::{Deserialize, Serialize};
use std::collections::btree_map::{BTreeMap, Entry};
use std::fmt;
use std::hash::Hash;
use std::ops::{Add, Div, Mul, Sub};
use zokrates_field::Field;
#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct QuadComb<T> {
pub left: LinComb<T>,
pub right: LinComb<T>,
}
impl<T: Field> PartialEq for QuadComb<T> {
fn eq(&self, other: &Self) -> bool {
self.left.eq(&other.left) && self.right.eq(&other.right)
}
}
impl<T: Field> Eq for QuadComb<T> {}
impl<T: Field> QuadComb<T> {
pub fn from_linear_combinations(left: LinComb<T>, right: LinComb<T>) -> Self {
QuadComb { left, right }
}
pub fn try_linear(&self) -> Option<LinComb<T>> {
// identify (k * ~ONE) * (lincomb) and return (k * lincomb)
match self.left.try_summand() {
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
return Some(self.right.clone() * &coefficient);
}
_ => {}
}
match self.right.try_summand() {
Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => {
return Some(self.left.clone() * &coefficient);
}
_ => {}
}
pub fn try_linear(self) -> Result<LinComb<T>, Self> {
// identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb)
// if not, error out with the input
if self.left.is_zero() || self.right.is_zero() {
return Some(LinComb::zero());
return Ok(LinComb::zero());
}
None
match self.left.try_constant() {
Ok(coefficient) => Ok(self.right * &coefficient),
Err(left) => match self.right.try_constant() {
Ok(coefficient) => Ok(left * &coefficient),
Err(right) => Err(QuadComb::from_linear_combinations(left, right)),
},
}
}
}
@ -66,17 +53,9 @@ impl<T: Field> fmt::Display for QuadComb<T> {
}
}
#[derive(Clone, Hash, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct LinComb<T>(pub Vec<(FlatVariable, T)>);
impl<T: Field> PartialEq for LinComb<T> {
fn eq(&self, other: &Self) -> bool {
self.clone().into_canonical() == other.clone().into_canonical()
}
}
impl<T: Field> Eq for LinComb<T> {}
#[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct CanonicalLinComb<T>(pub BTreeMap<FlatVariable, T>);
@ -113,36 +92,52 @@ impl<T> LinComb<T> {
}
pub fn is_zero(&self) -> bool {
self.0.len() == 0
self.0.is_empty()
}
}
impl<T: Field> LinComb<T> {
pub fn try_summand(&self) -> Option<(FlatVariable, T)> {
pub fn try_constant(self) -> Result<T, Self> {
match self.0.len() {
// if the lincomb is empty, it is not reduceable to a summand
0 => None,
// if the lincomb is empty, it is reduceable to 0
0 => Ok(T::zero()),
_ => {
// take the first variable in the lincomb
let first = &self.0[0].0;
self.0
.iter()
.map(|element| {
if first != &FlatVariable::one() {
return Err(self);
}
// all terms must contain the same variable
if self.0.iter().all(|element| element.0 == *first) {
Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1))
} else {
Err(self)
}
}
}
}
pub fn try_summand(self) -> Result<(FlatVariable, T), Self> {
match self.0.len() {
// if the lincomb is empty, it is not reduceable to a summand
0 => Err(self),
_ => {
// take the first variable in the lincomb
let first = &self.0[0].0;
if self.0.iter().all(|element|
// all terms must contain the same variable
if element.0 == *first {
// if they do, return the coefficient
Ok(&element.1)
} else {
// otherwise, stop
Err(())
}
})
// collect to a Result to short circuit when we hit an error
.collect::<Result<_, _>>()
// we didn't hit an error, do final processing. It's fine to clone here.
.map(|v: Vec<_>| (first.clone(), v.iter().fold(T::zero(), |acc, e| acc + *e)))
.ok()
element.0 == *first)
{
Ok((
*first,
self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1),
))
} else {
Err(self)
}
}
}
}
@ -207,9 +202,7 @@ impl<T: Field> fmt::Display for LinComb<T> {
false => write!(
f,
"{}",
self.clone()
.into_canonical()
.0
self.0
.iter()
.map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k))
.collect::<Vec<_>>()
@ -251,10 +244,14 @@ impl<T: Field> Mul<&T> for LinComb<T> {
type Output = LinComb<T>;
fn mul(self, scalar: &T) -> LinComb<T> {
if scalar == &T::one() {
return self;
}
LinComb(
self.0
.into_iter()
.map(|(var, coeff)| (var, coeff * scalar))
.map(|(var, coeff)| (var, coeff * scalar.clone()))
.collect(),
)
}
@ -262,7 +259,8 @@ impl<T: Field> Mul<&T> for LinComb<T> {
impl<T: Field> Div<&T> for LinComb<T> {
type Output = LinComb<T>;
// Clippy warns about multiplication in a method named div. It's okay, here, since we multiply with the inverse.
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, scalar: &T) -> LinComb<T> {
self * &scalar.inverse_mul().unwrap()
}
@ -287,7 +285,7 @@ mod tests {
fn add() {
let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
let c = a + b.clone();
let c = a + b;
let expected_vec = vec![
(FlatVariable::new(42), Bn128Field::from(1)),
@ -300,7 +298,7 @@ mod tests {
fn sub() {
let a: LinComb<Bn128Field> = FlatVariable::new(42).into();
let b: LinComb<Bn128Field> = FlatVariable::new(42).into();
let c = a - b.clone();
let c = a - b;
let expected_vec = vec![
(FlatVariable::new(42), Bn128Field::from(1)),
@ -314,7 +312,7 @@ mod tests {
fn display() {
let a: LinComb<Bn128Field> =
LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21));
assert_eq!(&a.to_string(), "3 * _21 + 1 * _42");
assert_eq!(&a.to_string(), "1 * _42 + 3 * _21");
let zero: LinComb<Bn128Field> = LinComb::zero();
assert_eq!(&zero.to_string(), "0");
}
@ -350,7 +348,7 @@ mod tests {
+ LinComb::summand(4, FlatVariable::new(33)),
right: LinComb::summand(1, FlatVariable::new(21)),
};
assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)");
assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)");
let a: QuadComb<Bn128Field> = QuadComb {
left: LinComb::zero(),
right: LinComb::summand(1, FlatVariable::new(21)),
@ -371,7 +369,7 @@ mod tests {
]);
assert_eq!(
summand.try_summand(),
Some((FlatVariable::new(42), Bn128Field::from(6)))
Ok((FlatVariable::new(42), Bn128Field::from(6)))
);
let not_summand = LinComb(vec![
@ -379,10 +377,10 @@ mod tests {
(FlatVariable::new(42), Bn128Field::from(2)),
(FlatVariable::new(42), Bn128Field::from(3)),
]);
assert_eq!(not_summand.try_summand(), None);
assert!(not_summand.try_summand().is_err());
let empty: LinComb<Bn128Field> = LinComb(vec![]);
assert_eq!(empty.try_summand(), None);
assert!(empty.try_summand().is_err());
}
}
}

View file

@ -125,7 +125,7 @@ impl<T: Field> From<FlatDirective<T>> for Directive<T> {
inputs: ds
.inputs
.into_iter()
.map(|i| QuadComb::from_flat_expression(i))
.map(QuadComb::from_flat_expression)
.collect(),
solver: ds.solver,
outputs: ds.outputs,

View file

@ -1,10 +1,13 @@
use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::Directive;
use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness};
use crate::solvers::{Executable, Solver};
use crate::solvers::Solver;
use pairing_ce::bn256::Bn256;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
#[cfg(feature = "bellman")]
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::Field;
pub type ExecutionResult<T> = Result<Witness<T>, Error>;
@ -34,13 +37,13 @@ impl Interpreter {
}
impl Interpreter {
pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &Vec<T>) -> ExecutionResult<T> {
pub fn execute<T: Field>(&self, program: &Prog<T>, inputs: &[T]) -> ExecutionResult<T> {
let main = &program.main;
self.check_inputs(&program, &inputs)?;
let mut witness = BTreeMap::new();
witness.insert(FlatVariable::one(), T::one());
for (arg, value) in main.arguments.iter().zip(inputs.iter()) {
witness.insert(arg.clone(), value.clone().into());
witness.insert(*arg, value.clone());
}
for statement in main.statements.iter() {
@ -48,7 +51,7 @@ impl Interpreter {
Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) {
true => {
let val = quad.evaluate(&witness).unwrap();
witness.insert(lin.0.iter().next().unwrap().0.clone(), val);
witness.insert(lin.0.get(0).unwrap().0, val);
}
false => {
let lhs_value = quad.evaluate(&witness).unwrap();
@ -76,10 +79,10 @@ impl Interpreter {
.iter()
.map(|i| i.evaluate(&witness).unwrap())
.collect();
match d.solver.execute(&inputs) {
match self.execute_solver(&d.solver, &inputs) {
Ok(res) => {
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone());
witness.insert(*o, res[i].clone());
}
continue;
}
@ -107,12 +110,12 @@ impl Interpreter {
value.to_biguint()
};
let mut num = input.clone();
let mut num = input;
let mut res = vec![];
let bits = T::get_required_bits();
for i in (0..bits).rev() {
if T::from(2).to_biguint().pow(i as usize) <= num {
num = num - T::from(2).to_biguint().pow(i as usize);
num -= T::from(2).to_biguint().pow(i as usize);
res.push(T::one());
} else {
res.push(T::zero());
@ -120,11 +123,11 @@ impl Interpreter {
}
assert_eq!(num, T::zero().to_biguint());
for (i, o) in d.outputs.iter().enumerate() {
witness.insert(o.clone(), res[i].clone());
witness.insert(*o, res[i].clone());
}
}
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &Vec<U>) -> Result<(), Error> {
fn check_inputs<T: Field, U>(&self, program: &Prog<T>, inputs: &[U]) -> Result<(), Error> {
if program.main.arguments.len() == inputs.len() {
Ok(())
} else {
@ -134,26 +137,136 @@ impl Interpreter {
})
}
}
pub fn execute_solver<T: Field>(
&self,
solver: &Solver,
inputs: &[T],
) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = solver.get_signature();
assert!(inputs.len() == expected_input_count);
let res = match solver {
Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()],
false => vec![
T::one(),
T::one().checked_div(&inputs[0]).unwrap_or_else(T::one),
],
},
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
} else {
res.push(T::zero());
}
}
res
}
Solver::Xor => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - T::from(2) * x * y]
}
Solver::Or => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - x * y]
}
// res = b * c - (2b * c - b - c) * (a)
Solver::ShaAndXorAndXorAnd => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
}
// res = a(b - c) + c
Solver::ShaCh => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![a * (b - c.clone()) + c]
}
Solver::Div => vec![inputs[0]
.clone()
.checked_div(&inputs[1])
.unwrap_or_else(T::one)],
Solver::EuclideanDiv => {
use num::CheckedDiv;
let n = inputs[0].clone().to_biguint();
let d = inputs[1].clone().to_biguint();
let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into());
let r = n - d * &q;
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
}
#[cfg(feature = "bellman")]
Solver::Sha256Round => {
use zokrates_field::Bn128Field;
assert_eq!(T::id(), Bn128Field::id());
let i = &inputs[0..512];
let h = &inputs[512..];
let to_fr = |x: &T| {
use pairing_ce::ff::{PrimeField, ScalarEngine};
let s = x.to_dec_string();
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
};
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
assert_eq!(h.len(), 256);
generate_sha256_round_witness::<Bn256>(&i, &h)
.into_iter()
.map(|x| {
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
let mut res: Vec<u8> = vec![];
x.into_repr().write_le(&mut res).unwrap();
T::from_byte_vector(res)
})
.collect()
}
};
assert_eq!(res.len(), expected_output_count);
Ok(res)
}
}
#[derive(Debug)]
pub struct EvaluationError;
impl<T: Field> LinComb<T> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> {
fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
self.0
.iter()
.map(|(var, mult)| witness.get(var).map(|v| v.clone() * mult).ok_or(())) // get each term
.map(|(var, mult)| {
witness
.get(var)
.map(|v| v.clone() * mult)
.ok_or(EvaluationError)
}) // get each term
.collect::<Result<Vec<_>, _>>() // fail if any term isn't found
.map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum
}
fn is_assignee<U>(&self, witness: &BTreeMap<FlatVariable, U>) -> bool {
self.0.iter().count() == 1
&& self.0.iter().next().unwrap().1 == T::from(1)
&& !witness.contains_key(&self.0.iter().next().unwrap().0)
&& self.0.get(0).unwrap().1 == T::from(1)
&& !witness.contains_key(&self.0.get(0).unwrap().0)
}
}
impl<T: Field> QuadComb<T> {
pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, ()> {
pub fn evaluate(&self, witness: &BTreeMap<FlatVariable, T>) -> Result<T, EvaluationError> {
let left = self.left.evaluate(&witness)?;
let right = self.right.evaluate(&witness)?;
Ok(left * right)
@ -192,3 +305,83 @@ impl fmt::Debug for Error {
write!(f, "{}", self)
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let interpreter = Interpreter::default();
let r = interpreter
.execute_solver(
&cond_eq,
&inputs
.iter()
.map(|&i| Bn128Field::from(i))
.collect::<Vec<_>>(),
)
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
#[test]
fn bits_of_one() {
let inputs = vec![Bn128Field::from(1)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for i in 0..253 {
assert_eq!(res[i], Bn128Field::from(0));
}
}
#[test]
fn bits_of_42() {
let inputs = vec![Bn128Field::from(42)];
let interpreter = Interpreter::default();
let res = interpreter
.execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs)
.unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
assert_eq!(res[250], Bn128Field::from(1));
assert_eq!(res[249], Bn128Field::from(0));
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
}

View file

@ -3,6 +3,7 @@ use crate::flat_absy::FlatVariable;
use crate::solvers::Solver;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::hash::Hash;
use zokrates_field::Field;
mod expression;
@ -19,26 +20,12 @@ pub use self::serialize::ProgEnum;
pub use self::interpreter::{Error, ExecutionResult, Interpreter};
pub use self::witness::Witness;
#[derive(Debug, Serialize, Deserialize, Clone, Hash)]
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
pub enum Statement<T> {
Constraint(QuadComb<T>, LinComb<T>),
Directive(Directive<T>),
}
impl<T: Field> PartialEq for Statement<T> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Statement::Constraint(l1, r1), Statement::Constraint(l2, r2)) => {
l1.eq(l2) && r1.eq(r2)
}
(Statement::Directive(d1), Statement::Directive(d2)) => d1.eq(d2),
_ => false,
}
}
}
impl<T: Field> Eq for Statement<T> {}
impl<T: Field> Statement<T> {
pub fn definition<U: Into<QuadComb<T>>>(v: FlatVariable, e: U) -> Self {
Statement::Constraint(e.into(), v.into())
@ -49,23 +36,13 @@ impl<T: Field> Statement<T> {
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Hash)]
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Directive<T> {
pub inputs: Vec<QuadComb<T>>,
pub outputs: Vec<FlatVariable>,
pub solver: Solver,
}
impl<T: Field> PartialEq for Directive<T> {
fn eq(&self, other: &Self) -> bool {
self.inputs.eq(&other.inputs)
&& self.outputs.eq(&other.outputs)
&& self.solver.eq(&other.solver)
}
}
impl<T: Field> Eq for Directive<T> {}
impl<T: Field> fmt::Display for Directive<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
@ -95,7 +72,7 @@ impl<T: Field> fmt::Display for Statement<T> {
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
pub struct Function<T> {
pub id: String,
pub statements: Vec<Statement<T>>,
@ -103,15 +80,6 @@ pub struct Function<T> {
pub returns: Vec<FlatVariable>,
}
impl<T: Field> PartialEq for Function<T> {
fn eq(&self, other: &Self) -> bool {
self.id.eq(&other.id)
&& self.statements.eq(&other.statements)
&& self.arguments.eq(&other.arguments)
&& self.returns.eq(&other.returns)
}
}
impl<T: Field> fmt::Display for Function<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
@ -138,27 +106,18 @@ impl<T: Field> fmt::Display for Function<T> {
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct Prog<T> {
pub main: Function<T>,
pub private: Vec<bool>,
}
impl<T: Field> PartialEq for Prog<T> {
fn eq(&self, other: &Self) -> bool {
self.main.eq(&other.main) && self.private.eq(&other.private)
}
}
impl<T: Field> Prog<T> {
pub fn constraint_count(&self) -> usize {
self.main
.statements
.iter()
.filter(|s| match s {
Statement::Constraint(..) => true,
_ => false,
})
.filter(|s| matches!(s, Statement::Constraint(..)))
.count()
}

View file

@ -16,9 +16,9 @@ pub enum ProgEnum {
impl<T: Field> Prog<T> {
pub fn serialize<W: Write>(&self, mut w: W) {
w.write(ZOKRATES_MAGIC).unwrap();
w.write(ZOKRATES_VERSION_1).unwrap();
w.write(&T::id()).unwrap();
w.write_all(ZOKRATES_MAGIC).unwrap();
w.write_all(ZOKRATES_VERSION_1).unwrap();
w.write_all(&T::id()).unwrap();
serialize_into(&mut w, self, Infinite).unwrap();
}

View file

@ -19,7 +19,7 @@ impl fmt::Display for Error {
}
}
pub fn process_macros<'ast, T: Field>(file: File<'ast>) -> Result<File<'ast>, Error> {
pub fn process_macros<T: Field>(file: File) -> Result<File, Error> {
match &file.pragma {
Some(pragma) => {
if T::name() != pragma.curve.name {

View file

@ -0,0 +1,10 @@
use crate::ir::{folder::Folder, LinComb};
use zokrates_field::Field;
pub struct Canonicalizer;
impl<T: Field> Folder<T> for Canonicalizer {
fn fold_linear_combination(&mut self, l: LinComb<T>) -> LinComb<T> {
l.into_canonical().into()
}
}

View file

@ -12,10 +12,10 @@
use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::folder::*;
use crate::ir::*;
use crate::optimizer::canonicalizer::Canonicalizer;
use crate::solvers::Solver;
use std::collections::hash_map::{Entry, HashMap};
use zokrates_field::Field;
#[derive(Debug)]
pub struct DirectiveOptimizer<T: Field> {
calls: HashMap<(Solver, Vec<QuadComb<T>>), Vec<FlatVariable>>,
@ -37,6 +37,23 @@ impl<T: Field> DirectiveOptimizer<T> {
}
impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
// in order to correcty identify duplicates, we need to first canonicalize the statements
let mut canonicalizer = Canonicalizer;
let f = Function {
statements: f
.statements
.into_iter()
.flat_map(|s| canonicalizer.fold_statement(s))
.collect(),
..f
};
fold_function(self, f)
}
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
match s {
Statement::Directive(d) => {
@ -49,7 +66,7 @@ impl<T: Field> Folder<T> for DirectiveOptimizer<T> {
}
Entry::Occupied(e) => {
self.substitution
.extend(d.outputs.into_iter().zip(e.get().into_iter().cloned()));
.extend(d.outputs.into_iter().zip(e.get().iter().cloned()));
vec![]
}
}

View file

@ -1,7 +1,8 @@
//! Module containing the `DuplicateOptimizer` to remove duplicate constraints
use crate::ir::folder::Folder;
use crate::ir::folder::*;
use crate::ir::*;
use crate::optimizer::canonicalizer::Canonicalizer;
use std::collections::{hash_map::DefaultHasher, HashSet};
use zokrates_field::Field;
@ -33,6 +34,22 @@ impl DuplicateOptimizer {
}
impl<T: Field> Folder<T> for DuplicateOptimizer {
fn fold_function(&mut self, f: Function<T>) -> Function<T> {
// in order to correcty identify duplicates, we need to first canonicalize the statements
let mut canonicalizer = Canonicalizer;
let f = Function {
statements: f
.statements
.into_iter()
.flat_map(|s| canonicalizer.fold_statement(s))
.collect(),
..f
};
fold_function(self, f)
}
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
let hashed = hash(&s);
let result = match self.seen.get(&hashed) {
@ -120,7 +137,7 @@ mod tests {
main: Function {
id: "main".to_string(),
statements: vec![
constraint.clone(),
constraint,
Statement::Constraint(
QuadComb::from_linear_combinations(
LinComb::summand(3, FlatVariable::new(42)),

View file

@ -4,6 +4,7 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod canonicalizer;
mod directive;
mod duplicate;
mod redefinition;
@ -26,7 +27,6 @@ impl<T: Field> Prog<T> {
// // deduplicate directives which take the same input
let r = DirectiveOptimizer::optimize(r);
// remove duplicate constraints
let r = DuplicateOptimizer::optimize(r);
r
DuplicateOptimizer::optimize(r)
}
}

View file

@ -40,7 +40,6 @@ use crate::flat_absy::flat_variable::FlatVariable;
use crate::ir::folder::{fold_function, Folder};
use crate::ir::LinComb;
use crate::ir::*;
use crate::solvers::Executable;
use std::collections::{HashMap, HashSet};
use zokrates_field::Field;
@ -53,7 +52,7 @@ pub struct RedefinitionOptimizer<T: Field> {
}
impl<T: Field> RedefinitionOptimizer<T> {
fn new() -> RedefinitionOptimizer<T> {
fn new() -> Self {
RedefinitionOptimizer {
substitution: HashMap::new(),
ignore: HashSet::new(),
@ -72,84 +71,77 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
let quad = self.fold_quadratic_combination(quad);
let lin = self.fold_linear_combination(lin);
let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() {
// if the right side is a single variable
Some((variable, coefficient)) => {
match self.ignore.contains(&variable) {
// if the variable isn't tagged as ignored
false => match self.substitution.get(&variable) {
// if the variable is already defined
Some(_) => (true, None, None),
// if the variable is not defined yet
None => match quad.try_linear() {
// if the left side is linear
Some(l) => (false, Some((variable, l / &coefficient)), None),
// if the left side isn't linear
None => (true, None, Some(variable)),
},
},
true => (true, None, None),
}
}
None => (true, None, None),
if lin.is_zero() {
return vec![Statement::Constraint(quad, lin)];
}
let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0)
|| self.substitution.contains_key(&lin.0[0].0)
{
true => (Some(Statement::Constraint(quad, lin)), None, None),
false => match lin.try_summand() {
// if the right side is a single variable
Ok((variable, coefficient)) => match quad.try_linear() {
// if the left side is linear
Ok(l) => (None, Some((variable, l / &coefficient)), None),
// if the left side isn't linear
Err(quad) => (
Some(Statement::Constraint(
quad,
LinComb::summand(coefficient, variable),
)),
None,
Some(variable),
),
},
Err(l) => (Some(Statement::Constraint(quad, l)), None, None),
},
};
// insert into the ignored set
match to_ignore {
Some(v) => {
self.ignore.insert(v);
}
None => {}
if let Some(v) = to_ignore {
self.ignore.insert(v);
}
// insert into the substitution map
match to_insert {
Some((k, v)) => {
self.substitution.insert(k, v.into_canonical());
}
None => {}
if let Some((k, v)) = to_insert {
self.substitution.insert(k, v.into_canonical());
};
// decide whether the constraint should be kept
match keep_constraint {
false => vec![],
true => vec![Statement::Constraint(quad, lin)],
match constraint {
Some(c) => vec![c],
_ => vec![],
}
}
Statement::Directive(d) => {
let d = self.fold_directive(d);
// check if the inputs are constants, ie reduce to the form `coeff * ~one`
let inputs = d
let inputs: Vec<_> = d
.inputs
.into_iter()
// we need to reduce to the canonical form to interpret `a + 1 - a` as `1`
.map(|i| i.reduce())
.map(|q| match q.try_linear() {
Some(l) => match l.0.len() {
// 0 is constant and can be represented by an empty lincomb
0 => Ok(T::from(0)),
_ => l
// try to match to a single summand `coeff * v`
.try_summand()
.map(|(variable, coefficient)| match variable {
// v must be ~one
v if v == FlatVariable::one() => Ok(coefficient),
_ => Err(LinComb::summand(coefficient, variable).into()),
})
.unwrap_or(Err(l.into())),
},
None => Err(q),
.map(|q| {
match q
.try_linear()
.map(|l| l.try_constant().map_err(|l| l.into()))
{
Ok(r) => r,
Err(e) => Err(e),
}
})
.collect::<Vec<Result<T, QuadComb<T>>>>();
match inputs.iter().all(|r| r.is_ok()) {
match inputs.iter().all(|i| i.is_ok()) {
true => {
// unwrap inputs to their constant value
let inputs = inputs.into_iter().map(|i| i.unwrap()).collect();
let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect();
// run the solver
let outputs = d.solver.execute(&inputs).unwrap();
let outputs = Interpreter::default()
.execute_solver(&d.solver, &inputs)
.unwrap();
assert_eq!(outputs.len(), d.outputs.len());
// insert the results in the substitution
@ -160,8 +152,8 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
vec![]
}
false => {
// reconstruct the input expressions
let inputs = inputs
//reconstruct the input expressions
let inputs: Vec<_> = inputs
.into_iter()
.map(|i| {
i.map(|v| LinComb::summand(v, FlatVariable::one()).into())
@ -183,8 +175,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
match lc
.0
.iter()
.find(|(variable, _)| self.substitution.get(&variable).is_some())
.is_some()
.any(|(variable, _)| self.substitution.get(&variable).is_some())
{
true =>
// for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is
@ -194,7 +185,7 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
self.substitution
.get(&variable)
.map(|l| LinComb::from(l.clone()) * &coefficient)
.unwrap_or(LinComb::summand(coefficient, variable))
.unwrap_or_else(|| LinComb::summand(coefficient, variable))
})
.fold(LinComb::zero(), |acc, x| acc + x)
}
@ -209,9 +200,6 @@ impl<T: Field> Folder<T> for RedefinitionOptimizer<T> {
}
fn fold_function(&mut self, fun: Function<T>) -> Function<T> {
self.substitution.drain();
self.ignore.drain();
// to prevent the optimiser from replacing outputs, add them to the ignored set
self.ignore.extend(fun.returns.iter().cloned());
@ -378,7 +366,7 @@ mod tests {
// ->
// def main(x, y) -> (1):
// 6*x + 6*y == 6*x + 6*y // will be eliminated as a tautology
// 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology
// return 6*x + 6*y
let x = FlatVariable::new(0);
@ -412,7 +400,15 @@ mod tests {
LinComb::summand(6, x) + LinComb::summand(6, y),
LinComb::summand(6, x) + LinComb::summand(6, y),
),
Statement::definition(r, LinComb::summand(6, x) + LinComb::summand(6, y)),
Statement::definition(
r,
LinComb::summand(1, x)
+ LinComb::summand(1, y)
+ LinComb::summand(2, x)
+ LinComb::summand(2, y)
+ LinComb::summand(3, x)
+ LinComb::summand(3, y),
),
],
returns: vec![r],
};

View file

@ -25,17 +25,16 @@ impl TautologyOptimizer {
impl<T: Field> Folder<T> for TautologyOptimizer {
fn fold_statement(&mut self, s: Statement<T>) -> Vec<Statement<T>> {
match s {
Statement::Constraint(quad, lin) => {
match quad.try_linear() {
Some(l) => {
if l == lin {
return vec![];
}
Statement::Constraint(quad, lin) => match quad.try_linear() {
Ok(l) => {
if l == lin {
vec![]
} else {
vec![Statement::Constraint(l.into(), lin)]
}
None => {}
}
vec![Statement::Constraint(quad, lin)]
}
Err(quad) => vec![Statement::Constraint(quad, lin)],
},
_ => fold_statement(self, s),
}
}

View file

@ -78,7 +78,7 @@ impl<T: Field + ArkFieldExtensions + NotBw6_761Field> Backend<T, GM17> for Ark {
query: vk
.query
.into_iter()
.map(|g1| serialization::to_g1::<T>(g1))
.map(serialization::to_g1::<T>)
.collect(),
};
@ -172,7 +172,7 @@ impl Backend<Bw6_761Field, GM17> for Ark {
query: vk
.query
.into_iter()
.map(|g1| serialization::to_g1::<Bw6_761Field>(g1))
.map(serialization::to_g1::<Bw6_761Field>)
.collect(),
};

View file

@ -49,42 +49,33 @@ fn ark_combination<T: Field + ArkFieldExtensions>(
cs: &mut ConstraintSystem<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>,
symbols: &mut BTreeMap<FlatVariable, Variable>,
witness: &mut Witness<T>,
) -> Result<
LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>,
SynthesisError,
> {
let lc =
l.0.into_iter()
.map(|(k, v)| {
(
v.into_ark(),
symbols
.entry(k)
.or_insert_with(|| {
match k.is_output() {
true => cs.new_input_variable(|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_ark())
}),
false => cs.new_witness_variable(|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_ark())
}),
}
.unwrap()
})
.clone(),
)
})
.fold(LinearCombination::zero(), |acc, e| acc + e);
Ok(lc)
) -> LinearCombination<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr> {
l.0.into_iter()
.map(|(k, v)| {
(
v.into_ark(),
*symbols.entry(k).or_insert_with(|| {
match k.is_output() {
true => cs.new_input_variable(|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_ark())
}),
false => cs.new_witness_variable(|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_ark())
}),
}
.unwrap()
}),
)
})
.fold(LinearCombination::zero(), |acc, e| acc + e)
}
impl<T: Field + ArkFieldExtensions> Prog<T> {
@ -96,7 +87,7 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
// mapping from IR variables
let mut symbols = BTreeMap::new();
let mut witness = witness.unwrap_or(Witness::empty());
let mut witness = witness.unwrap_or_else(Witness::empty);
assert!(symbols.insert(FlatVariable::one(), ConstraintSystem::<<<T as ArkFieldExtensions>::ArkEngine as PairingEngine>::Fr>::one()).is_none());
@ -127,37 +118,34 @@ impl<T: Field + ArkFieldExtensions> Prog<T> {
}),
}
.unwrap();
(var.clone(), wire)
(*var, wire)
}),
);
let main = self.main;
for statement in main.statements {
match statement {
Statement::Constraint(quad, lin) => {
let a = ark_combination(
quad.left.clone().into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
)?;
let b = ark_combination(
quad.right.clone().into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
)?;
let c = ark_combination(
lin.into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
)?;
if let Statement::Constraint(quad, lin) = statement {
let a = ark_combination(
quad.left.clone().into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
);
let b = ark_combination(
quad.right.clone().into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
);
let c = ark_combination(
lin.into_canonical(),
&mut cs,
&mut symbols,
&mut witness,
);
cs.enforce_constraint(a, b, c)?;
}
_ => {}
cs.enforce_constraint(a, b, c)?;
}
}

View file

@ -81,7 +81,7 @@ impl<T: Field + BellmanFieldExtensions> Backend<T, G16> for Bellman {
ic: vk
.gamma_abc
.into_iter()
.map(|g1| serialization::to_g1::<T>(g1))
.map(serialization::to_g1::<T>)
.collect(),
};

View file

@ -51,34 +51,31 @@ fn bellman_combination<T: BellmanFieldExtensions, CS: ConstraintSystem<T::Bellma
.map(|(k, v)| {
(
v.into_bellman(),
symbols
.entry(k)
.or_insert_with(|| {
match k.is_output() {
true => cs.alloc_input(
|| format!("{}", k),
|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_bellman())
},
),
false => cs.alloc(
|| format!("{}", k),
|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_bellman())
},
),
}
.unwrap()
})
.clone(),
*symbols.entry(k).or_insert_with(|| {
match k.is_output() {
true => cs.alloc_input(
|| format!("{}", k),
|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_bellman())
},
),
false => cs.alloc(
|| format!("{}", k),
|| {
Ok(witness
.0
.remove(&k)
.ok_or(SynthesisError::AssignmentMissing)?
.into_bellman())
},
),
}
.unwrap()
}),
)
})
.fold(LinearCombination::zero(), |acc, e| acc + e)
@ -93,7 +90,7 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
// mapping from IR variables
let mut symbols = BTreeMap::new();
let mut witness = witness.unwrap_or(Witness::empty());
let mut witness = witness.unwrap_or_else(Witness::empty);
assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none());
@ -127,33 +124,29 @@ impl<T: BellmanFieldExtensions + Field> Prog<T> {
),
}
.unwrap();
(var.clone(), wire)
(*var, wire)
}),
);
let main = self.main;
for statement in main.statements {
match statement {
Statement::Constraint(quad, lin) => {
let a = &bellman_combination(
quad.left.into_canonical(),
cs,
&mut symbols,
&mut witness,
);
let b = &bellman_combination(
quad.right.into_canonical(),
cs,
&mut symbols,
&mut witness,
);
let c =
&bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
if let Statement::Constraint(quad, lin) = statement {
let a = &bellman_combination(
quad.left.into_canonical(),
cs,
&mut symbols,
&mut witness,
);
let b = &bellman_combination(
quad.right.into_canonical(),
cs,
&mut symbols,
&mut witness,
);
let c = &bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness);
cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c);
}
_ => {}
cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c);
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,10 +1,5 @@
#[cfg(feature = "bellman")]
use pairing_ce::bn256::Bn256;
use serde::{Deserialize, Serialize};
use std::fmt;
#[cfg(feature = "bellman")]
use zokrates_embed::generate_sha256_round_witness;
use zokrates_field::{Bn128Field, Field};
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
pub enum Solver {
@ -48,168 +43,3 @@ impl Solver {
Solver::Bits(width)
}
}
pub trait Executable<T> {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String>;
}
impl<T: Field> Executable<T> for Solver {
fn execute(&self, inputs: &Vec<T>) -> Result<Vec<T>, String> {
let (expected_input_count, expected_output_count) = self.get_signature();
assert_eq!(inputs.len(), expected_input_count);
let res = match self {
Solver::ConditionEq => match inputs[0].is_zero() {
true => vec![T::zero(), T::one()],
false => vec![
T::one(),
T::one().checked_div(&inputs[0]).unwrap_or(T::one()),
],
},
Solver::Bits(bit_width) => {
let mut num = inputs[0].clone();
let mut res = vec![];
for i in (0..*bit_width).rev() {
if T::from(2).pow(i) <= num {
num = num - T::from(2).pow(i);
res.push(T::one());
} else {
res.push(T::zero());
}
}
res
}
Solver::Xor => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - T::from(2) * x * y]
}
Solver::Or => {
let x = inputs[0].clone();
let y = inputs[1].clone();
vec![x.clone() + y.clone() - x * y]
}
// res = b * c - (2b * c - b - c) * (a)
Solver::ShaAndXorAndXorAnd => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a]
}
// res = a(b - c) + c
Solver::ShaCh => {
let a = inputs[0].clone();
let b = inputs[1].clone();
let c = inputs[2].clone();
vec![a * (b - c.clone()) + c]
}
Solver::Div => vec![inputs[0]
.clone()
.checked_div(&inputs[1])
.unwrap_or(T::one())],
Solver::EuclideanDiv => {
use num::CheckedDiv;
let n = inputs[0].clone().to_biguint();
let d = inputs[1].clone().to_biguint();
let q = n.checked_div(&d).unwrap_or(0u32.into());
let r = n - d * &q;
vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()]
}
#[cfg(feature = "bellman")]
Solver::Sha256Round => {
assert_eq!(T::id(), Bn128Field::id());
let i = &inputs[0..512];
let h = &inputs[512..];
let to_fr = |x: &T| {
use pairing_ce::ff::{PrimeField, ScalarEngine};
let s = x.to_dec_string();
<Bn256 as ScalarEngine>::Fr::from_str(&s).unwrap()
};
let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect();
let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect();
assert_eq!(h.len(), 256);
generate_sha256_round_witness::<Bn256>(&i, &h)
.into_iter()
.map(|x| {
use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr};
let mut res: Vec<u8> = vec![];
x.into_repr().write_le(&mut res).unwrap();
T::from_byte_vector(res)
})
.collect()
}
};
assert_eq!(res.len(), expected_output_count);
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
use zokrates_field::Bn128Field;
mod eq_condition {
// Wanted: (Y = (X != 0) ? 1 : 0)
// # Y = if X == 0 then 0 else 1 fi
// # M = if X == 0 then 1 else 1/X fi
use super::*;
#[test]
fn execute() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![0];
let r = cond_eq
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
.unwrap();
let res: Vec<Bn128Field> = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
#[test]
fn execute_non_eq() {
let cond_eq = Solver::ConditionEq;
let inputs = vec![1];
let r = cond_eq
.execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect())
.unwrap();
let res: Vec<Bn128Field> = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect();
assert_eq!(r, &res[..]);
}
}
#[test]
fn bits_of_one() {
let bits = Solver::Bits(Bn128Field::get_required_bits());
let inputs = vec![Bn128Field::from(1)];
let res = bits.execute(&inputs).unwrap();
assert_eq!(res[253], Bn128Field::from(1));
for i in 0..253 {
assert_eq!(res[i], Bn128Field::from(0));
}
}
#[test]
fn bits_of_42() {
let bits = Solver::Bits(Bn128Field::get_required_bits());
let inputs = vec![Bn128Field::from(42)];
let res = bits.execute(&inputs).unwrap();
assert_eq!(res[253], Bn128Field::from(0));
assert_eq!(res[252], Bn128Field::from(1));
assert_eq!(res[251], Bn128Field::from(0));
assert_eq!(res[250], Bn128Field::from(1));
assert_eq!(res[249], Bn128Field::from(0));
assert_eq!(res[248], Bn128Field::from(1));
assert_eq!(res[247], Bn128Field::from(0));
}
}

View file

@ -0,0 +1,134 @@
use crate::typed_absy::result_folder::*;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct BoundsChecker;
pub type Error = String;
impl BoundsChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
BoundsChecker.fold_program(p)
}
pub fn check_select<'ast, T: Field, U: Select<'ast, T>>(
&mut self,
array: ArrayExpression<'ast, T>,
index: UExpression<'ast, T>,
) -> Result<U, Error> {
let array = self.fold_array_expression(array)?;
let index = self.fold_uint_expression(index)?;
match (array.get_array_type().size.as_inner(), index.as_inner()) {
(UExpressionInner::Value(size), UExpressionInner::Value(index)) => {
if index >= size {
return Err(format!(
"Out of bounds access: {}[{}] but {} is of size {}",
array, index, array, size
));
}
}
_ => unreachable!(),
};
Ok(U::select(array, index))
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker {
type Error = Error;
fn fold_array_expression_inner(
&mut self,
ty: &ArrayType<'ast, T>,
e: ArrayExpressionInner<'ast, T>,
) -> Result<ArrayExpressionInner<'ast, T>, Self::Error> {
match e {
ArrayExpressionInner::Select(box array, box index) => self
.check_select::<_, ArrayExpression<_>>(array, index)
.map(|a| a.into_inner()),
ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = self.fold_array_expression(array)?;
let from = self.fold_uint_expression(from)?;
let to = self.fold_uint_expression(to)?;
match (
array.get_array_type().size.as_inner(),
from.as_inner(),
to.as_inner(),
) {
(
UExpressionInner::Value(size),
UExpressionInner::Value(from),
UExpressionInner::Value(to),
) => {
if from > to {
return Err(format!(
"Slice is created from an invalid range {}..{}",
from, to
));
}
if from > size {
return Err(format!("Lower bound {} of slice {}[{}..{}] is out of bounds for array of size {}", from, array, from, to, size));
}
if to > size {
return Err(format!("Upper bound {} of slice {}[{}..{}] is out of bounds for array of size {}", to, array, from, to, size));
}
}
_ => unreachable!(),
};
Ok(ArrayExpressionInner::Slice(box array, box from, box to))
}
e => fold_array_expression_inner(self, ty, e),
}
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType<'ast, T>,
e: StructExpressionInner<'ast, T>,
) -> Result<StructExpressionInner<'ast, T>, Self::Error> {
match e {
StructExpressionInner::Select(box array, box index) => self
.check_select::<_, StructExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_struct_expression_inner(self, ty, e),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_field_expression(self, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> Result<BooleanExpression<'ast, T>, Self::Error> {
match e {
BooleanExpression::Select(box array, box index) => self.check_select(array, index),
e => fold_boolean_expression(self, e),
}
}
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Self::Error> {
match e {
UExpressionInner::Select(box array, box index) => self
.check_select::<_, UExpression<_>>(array, index)
.map(|a| a.into_inner()),
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
}

View file

@ -1,31 +1,34 @@
use crate::typed_absy;
use crate::typed_absy::types::{StructType, UBitwidth};
use crate::typed_absy::types::UBitwidth;
use crate::zir;
use std::marker::PhantomData;
use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
pub struct Flattener<T: Field> {
phantom: PhantomData<T>,
}
fn flatten_identifier_rec<'a>(
id: zir::SourceIdentifier<'a>,
ty: &typed_absy::Type,
) -> Vec<zir::Variable<'a>> {
fn flatten_identifier_rec<'ast>(
id: zir::SourceIdentifier<'ast>,
ty: &typed_absy::types::ConcreteType,
) -> Vec<zir::Variable<'ast>> {
match ty {
typed_absy::Type::FieldElement => vec![zir::Variable {
typed_absy::ConcreteType::Int => unreachable!(),
typed_absy::ConcreteType::FieldElement => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::FieldElement,
}],
typed_absy::Type::Boolean => vec![zir::Variable {
typed_absy::types::ConcreteType::Boolean => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::Boolean,
}],
typed_absy::Type::Uint(bitwidth) => vec![zir::Variable {
typed_absy::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable {
id: zir::Identifier::Source(id),
_type: zir::Type::uint(bitwidth.to_usize()),
}],
typed_absy::Type::Array(array_type) => (0..array_type.size)
typed_absy::types::ConcreteType::Array(array_type) => (0..array_type.size)
.flat_map(|i| {
flatten_identifier_rec(
zir::SourceIdentifier::Select(box id.clone(), i),
@ -33,7 +36,7 @@ fn flatten_identifier_rec<'a>(
)
})
.collect(),
typed_absy::Type::Struct(members) => members
typed_absy::types::ConcreteType::Struct(members) => members
.iter()
.flat_map(|struct_member| {
flatten_identifier_rec(
@ -57,17 +60,6 @@ impl<'ast, T: Field> Flattener<T> {
fold_program(self, p)
}
fn fold_module(&mut self, p: typed_absy::TypedModule<'ast, T>) -> zir::ZirModule<'ast, T> {
fold_module(self, p)
}
fn fold_function_symbol(
&mut self,
s: typed_absy::TypedFunctionSymbol<'ast, T>,
) -> zir::ZirFunctionSymbol<'ast, T> {
fold_function_symbol(self, s)
}
fn fold_function(
&mut self,
f: typed_absy::TypedFunction<'ast, T>,
@ -75,9 +67,12 @@ impl<'ast, T: Field> Flattener<T> {
fold_function(self, f)
}
fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec<zir::Parameter<'ast>> {
fn fold_declaration_parameter(
&mut self,
p: typed_absy::DeclarationParameter<'ast>,
) -> Vec<zir::Parameter<'ast>> {
let private = p.private;
self.fold_variable(p.id)
self.fold_variable(p.id.try_into().unwrap())
.into_iter()
.map(|v| zir::Parameter { id: v, private })
.collect()
@ -87,10 +82,12 @@ impl<'ast, T: Field> Flattener<T> {
zir::SourceIdentifier::Basic(n)
}
fn fold_variable(&mut self, v: typed_absy::Variable<'ast>) -> Vec<zir::Variable<'ast>> {
fn fold_variable(&mut self, v: typed_absy::Variable<'ast, T>) -> Vec<zir::Variable<'ast>> {
let id = self.fold_name(v.id.clone());
let ty = v.get_type();
let ty = typed_absy::types::ConcreteType::try_from(ty).unwrap();
flatten_identifier_rec(id, &ty)
}
@ -102,36 +99,34 @@ impl<'ast, T: Field> Flattener<T> {
typed_absy::TypedAssignee::Identifier(v) => self.fold_variable(v),
typed_absy::TypedAssignee::Select(box a, box i) => {
use typed_absy::Typed;
let count = match a.get_type() {
typed_absy::Type::Array(array_ty) => array_ty.ty.get_primitive_count(),
let count = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap() {
typed_absy::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(),
_ => unreachable!(),
};
let a = self.fold_assignee(a);
match i {
typed_absy::FieldElementExpression::Number(n) => {
let index = n.to_dec_string().parse::<usize>().unwrap();
a[index * count..(index + 1) * count].to_vec()
match i.as_inner() {
typed_absy::UExpressionInner::Value(index) => {
a[*index as usize * count..(*index as usize + 1) * count].to_vec()
}
i => unreachable!("index {} not allowed, should be a constant", i),
i => unreachable!("index {:?} not allowed, should be a constant", i),
}
}
typed_absy::TypedAssignee::Member(box a, m) => {
use typed_absy::Typed;
let (offset, size) = match a.get_type() {
typed_absy::Type::Struct(struct_type) => {
struct_type
.members
.iter()
.fold((0, None), |(offset, size), member| match size {
Some(_) => (offset, size),
None => match m == member.id {
true => (offset, Some(member.ty.get_primitive_count())),
false => (offset + member.ty.get_primitive_count(), None),
},
})
}
let (offset, size) = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap()
{
typed_absy::ConcreteType::Struct(struct_type) => struct_type
.members
.iter()
.fold((0, None), |(offset, size), member| match size {
Some(_) => (offset, size),
None => match m == member.id {
true => (offset, Some(member.ty.get_primitive_count())),
false => (offset + member.ty.get_primitive_count(), None),
},
}),
_ => unreachable!(),
};
@ -151,6 +146,16 @@ impl<'ast, T: Field> Flattener<T> {
fold_statement(self, s)
}
fn fold_expression_or_spread(
&mut self,
e: typed_absy::TypedExpressionOrSpread<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
typed_absy::TypedExpressionOrSpread::Expression(e) => self.fold_expression(e),
typed_absy::TypedExpressionOrSpread::Spread(s) => self.fold_array_expression(s.array),
}
}
fn fold_expression(
&mut self,
e: typed_absy::TypedExpression<'ast, T>,
@ -161,8 +166,9 @@ impl<'ast, T: Field> Flattener<T> {
}
typed_absy::TypedExpression::Boolean(e) => vec![self.fold_boolean_expression(e).into()],
typed_absy::TypedExpression::Uint(e) => vec![self.fold_uint_expression(e).into()],
typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e).into(),
typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e).into(),
typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e),
typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e),
typed_absy::TypedExpression::Int(_) => unreachable!(),
}
}
@ -185,26 +191,20 @@ impl<'ast, T: Field> Flattener<T> {
es: typed_absy::TypedExpressionList<'ast, T>,
) -> zir::ZirExpressionList<'ast, T> {
match es {
typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => {
zir::ZirExpressionList::FunctionCall(
self.fold_function_key(id),
typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => {
zir::ZirExpressionList::EmbedCall(
embed,
generics,
arguments
.into_iter()
.flat_map(|a| self.fold_expression(a))
.collect(),
vec![],
)
}
_ => unreachable!("should have been inlined"),
}
}
fn fold_function_key(
&mut self,
k: typed_absy::types::FunctionKey<'ast>,
) -> zir::types::FunctionKey<'ast> {
k.into()
}
fn fold_field_expression(
&mut self,
e: typed_absy::FieldElementExpression<'ast, T>,
@ -234,7 +234,7 @@ impl<'ast, T: Field> Flattener<T> {
fn fold_array_expression_inner(
&mut self,
ty: &typed_absy::Type,
ty: &typed_absy::types::ConcreteType,
size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
@ -242,26 +242,13 @@ impl<'ast, T: Field> Flattener<T> {
}
fn fold_struct_expression_inner(
&mut self,
ty: &StructType,
ty: &typed_absy::types::ConcreteStructType,
e: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
fold_struct_expression_inner(self, ty, e)
}
}
pub fn fold_module<'ast, T: Field>(
f: &mut Flattener<T>,
p: typed_absy::TypedModule<'ast, T>,
) -> zir::ZirModule<'ast, T> {
zir::ZirModule {
functions: p
.functions
.into_iter()
.map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun)))
.collect(),
}
}
pub fn fold_statement<'ast, T: Field>(
f: &mut Flattener<T>,
s: typed_absy::TypedStatement<'ast, T>,
@ -284,9 +271,7 @@ pub fn fold_statement<'ast, T: Field>(
}
typed_absy::TypedStatement::Declaration(v) => {
let v = f.fold_variable(v);
v.into_iter()
.map(|v| zir::ZirStatement::Declaration(v))
.collect()
v.into_iter().map(zir::ZirStatement::Declaration).collect()
}
typed_absy::TypedStatement::Assertion(e) => {
let e = f.fold_boolean_expression(e);
@ -302,19 +287,23 @@ pub fn fold_statement<'ast, T: Field>(
f.fold_expression_list(elist),
)]
}
typed_absy::TypedStatement::PushCallLog(..) => vec![],
typed_absy::TypedStatement::PopCallLog => vec![],
}
}
pub fn fold_array_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
t: &typed_absy::Type,
ty: &typed_absy::types::ConcreteType,
size: usize,
e: typed_absy::ArrayExpressionInner<'ast, T>,
array: typed_absy::ArrayExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
match array {
typed_absy::ArrayExpressionInner::Identifier(id) => {
let variables =
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::array(t.clone(), size));
let variables = flatten_identifier_rec(
f.fold_name(id),
&typed_absy::types::ConcreteType::array((ty.clone(), size)),
);
variables
.into_iter()
.map(|v| match v._type {
@ -326,10 +315,16 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
})
.collect()
}
typed_absy::ArrayExpressionInner::Value(exprs) => exprs
.into_iter()
.flat_map(|e| f.fold_expression(e))
.collect(),
typed_absy::ArrayExpressionInner::Value(exprs) => {
let exprs: Vec<_> = exprs
.into_iter()
.flat_map(|e| f.fold_expression_or_spread(e))
.collect();
assert_eq!(exprs.len(), size * ty.get_primitive_count());
exprs
}
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
typed_absy::ArrayExpressionInner::IfElse(
box condition,
@ -369,40 +364,74 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count())
.map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum();
// we also need the size of this member
let size = t.get_primitive_count() * size;
let size = ty.get_primitive_count() * size;
s[offset..offset + size].to_vec()
}
typed_absy::ArrayExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
let index = f.fold_uint_expression(index);
match index {
zir::FieldElementExpression::Number(i) => {
let size = t.get_primitive_count() * size;
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
let size = ty.clone().get_primitive_count() * size;
let start = i as usize * size;
let end = start + size;
array[start..end].to_vec()
}
_ => unreachable!(),
}
}
typed_absy::ArrayExpressionInner::Slice(box array, box from, box to) => {
let array = f.fold_array_expression(array);
let from = f.fold_uint_expression(from);
let to = f.fold_uint_expression(to);
match (from.into_inner(), to.into_inner()) {
(zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => {
assert_eq!(size, to.saturating_sub(from) as usize);
let element_size = ty.get_primitive_count();
let start = from as usize * element_size;
let end = to as usize * element_size;
array[start..end].to_vec()
}
_ => unreachable!(),
}
}
typed_absy::ArrayExpressionInner::Repeat(box e, box count) => {
let e = f.fold_expression(e);
let count = f.fold_uint_expression(count);
match count.into_inner() {
zir::UExpressionInner::Value(count) => {
vec![e; count as usize].into_iter().flatten().collect()
}
_ => unreachable!(),
}
}
}
}
pub fn fold_struct_expression_inner<'ast, T: Field>(
f: &mut Flattener<T>,
t: &StructType,
e: typed_absy::StructExpressionInner<'ast, T>,
ty: &typed_absy::types::ConcreteStructType,
struc: typed_absy::StructExpressionInner<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
match e {
match struc {
typed_absy::StructExpressionInner::Identifier(id) => {
let variables =
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::struc(t.clone()));
let variables = flatten_identifier_rec(
f.fold_name(id),
&typed_absy::types::ConcreteType::struc(ty.clone()),
);
variables
.into_iter()
.map(|v| match v._type {
@ -457,13 +486,18 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count())
.map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum();
// we also need the size of this member
let size = t
let size = ty
.iter()
.find(|member| member.id == id)
.cloned()
.unwrap()
.ty
.get_primitive_count();
@ -472,15 +506,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
}
typed_absy::StructExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
let index = f.fold_uint_expression(index);
match index {
zir::FieldElementExpression::Number(i) => {
let size = t
.iter()
.map(|m| m.ty.get_primitive_count())
.fold(0, |acc, current| acc + current);
let start = i.to_dec_string().parse::<usize>().unwrap() * size;
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
let size: usize = ty.iter().map(|m| m.ty.get_primitive_count()).sum();
let start = i as usize * size;
let end = start + size;
array[start..end].to_vec()
}
@ -498,9 +529,12 @@ pub fn fold_field_expression<'ast, T: Field>(
typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n),
typed_absy::FieldElementExpression::Identifier(id) => {
zir::FieldElementExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::FieldElement)[0]
.id
.clone(),
flatten_identifier_rec(
f.fold_name(id),
&typed_absy::types::ConcreteType::FieldElement,
)[0]
.id
.clone(),
)
}
typed_absy::FieldElementExpression::Add(box e1, box e2) => {
@ -525,7 +559,7 @@ pub fn fold_field_expression<'ast, T: Field>(
}
typed_absy::FieldElementExpression::Pow(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
let e2 = f.fold_uint_expression(e2);
zir::FieldElementExpression::Pow(box e1, box e2)
}
typed_absy::FieldElementExpression::Neg(box e) => {
@ -556,26 +590,22 @@ pub fn fold_field_expression<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count())
.map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap()
}
typed_absy::FieldElementExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
let index = f.fold_uint_expression(index);
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
_ => unreachable!(""),
}
}
@ -589,7 +619,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
match e {
typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v),
typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Boolean)[0]
flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0]
.id
.clone(),
),
@ -665,25 +695,45 @@ pub fn fold_boolean_expression<'ast, T: Field>(
zir::BooleanExpression::UintEq(box e1, box e2)
}
typed_absy::BooleanExpression::Lt(box e1, box e2) => {
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Lt(box e1, box e2)
zir::BooleanExpression::FieldLt(box e1, box e2)
}
typed_absy::BooleanExpression::Le(box e1, box e2) => {
typed_absy::BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Le(box e1, box e2)
zir::BooleanExpression::FieldLe(box e1, box e2)
}
typed_absy::BooleanExpression::Gt(box e1, box e2) => {
typed_absy::BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Gt(box e1, box e2)
zir::BooleanExpression::FieldGt(box e1, box e2)
}
typed_absy::BooleanExpression::Ge(box e1, box e2) => {
typed_absy::BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
zir::BooleanExpression::Ge(box e1, box e2)
zir::BooleanExpression::FieldGe(box e1, box e2)
}
typed_absy::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLt(box e1, box e2)
}
typed_absy::BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintLe(box e1, box e2)
}
typed_absy::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGt(box e1, box e2)
}
typed_absy::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
zir::BooleanExpression::UintGe(box e1, box e2)
}
typed_absy::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
@ -714,25 +764,21 @@ pub fn fold_boolean_expression<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count())
.map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum();
use std::convert::TryInto;
s[offset].clone().try_into().unwrap()
}
typed_absy::BooleanExpression::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
let index = f.fold_uint_expression(index);
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => array
[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap(),
match index.into_inner() {
zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(),
_ => unreachable!(),
}
}
@ -755,9 +801,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
match e {
typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v),
typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier(
flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Uint(bitwidth))[0]
.id
.clone(),
flatten_identifier_rec(
f.fold_name(id),
&typed_absy::types::ConcreteType::Uint(bitwidth),
)[0]
.id
.clone(),
),
typed_absy::UExpressionInner::Add(box left, box right) => {
let left = f.fold_uint_expression(left);
@ -771,6 +820,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
zir::UExpressionInner::Sub(box left, box right)
}
typed_absy::UExpressionInner::FloorSub(..) => unreachable!(),
typed_absy::UExpressionInner::Mult(box left, box right) => {
let left = f.fold_uint_expression(left);
let right = f.fold_uint_expression(right);
@ -827,11 +877,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
typed_absy::UExpressionInner::Neg(box e) => {
let bitwidth = e.bitwidth();
f.fold_uint_expression(typed_absy::UExpression::sub(
typed_absy::UExpressionInner::Value(0).annotate(bitwidth),
e,
))
.into_inner()
f.fold_uint_expression(typed_absy::UExpressionInner::Value(0).annotate(bitwidth) - e)
.into_inner()
}
typed_absy::UExpressionInner::Pos(box e) => {
@ -844,16 +891,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
}
typed_absy::UExpressionInner::Select(box array, box index) => {
let array = f.fold_array_expression(array);
let index = f.fold_field_expression(index);
let index = f.fold_uint_expression(index);
use std::convert::TryInto;
match index {
zir::FieldElementExpression::Number(i) => {
let e: zir::UExpression<_> = array[i.to_dec_string().parse::<usize>().unwrap()]
.clone()
.try_into()
.unwrap();
match index.into_inner() {
zir::UExpressionInner::Value(i) => {
let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap();
e.into_inner()
}
_ => unreachable!(),
@ -867,11 +909,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
let offset: usize = members
.iter()
.take_while(|member| member.id != id)
.map(|member| member.ty.get_primitive_count())
.map(|member| {
typed_absy::types::ConcreteType::try_from(*member.ty.clone())
.unwrap()
.get_primitive_count()
})
.sum();
use std::convert::TryInto;
let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap();
res.into_inner()
@ -893,14 +937,18 @@ pub fn fold_function<'ast, T: Field>(
arguments: fun
.arguments
.into_iter()
.flat_map(|a| f.fold_parameter(a))
.flat_map(|a| f.fold_declaration_parameter(a))
.collect(),
statements: fun
.statements
.into_iter()
.flat_map(|s| f.fold_statement(s))
.collect(),
signature: fun.signature.into(),
signature: typed_absy::types::ConcreteSignature::try_from(
typed_absy::types::Signature::<T>::try_from(fun.signature).unwrap(),
)
.unwrap()
.into(),
}
}
@ -908,41 +956,46 @@ pub fn fold_array_expression<'ast, T: Field>(
f: &mut Flattener<T>,
e: typed_absy::ArrayExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner())
let size = match e.size().into_inner() {
typed_absy::UExpressionInner::Value(v) => v,
_ => unreachable!(),
} as usize;
f.fold_array_expression_inner(
&typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(),
size,
e.into_inner(),
)
}
pub fn fold_struct_expression<'ast, T: Field>(
f: &mut Flattener<T>,
e: typed_absy::StructExpression<'ast, T>,
) -> Vec<zir::ZirExpression<'ast, T>> {
f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner())
}
pub fn fold_function_symbol<'ast, T: Field>(
f: &mut Flattener<T>,
s: typed_absy::TypedFunctionSymbol<'ast, T>,
) -> zir::ZirFunctionSymbol<'ast, T> {
match s {
typed_absy::TypedFunctionSymbol::Here(fun) => {
zir::ZirFunctionSymbol::Here(f.fold_function(fun))
}
typed_absy::TypedFunctionSymbol::There(key, module) => {
zir::ZirFunctionSymbol::There(f.fold_function_key(key), module)
} // by default, do not fold modules recursively
typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat),
}
f.fold_struct_expression_inner(
&typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(),
e.into_inner(),
)
}
pub fn fold_program<'ast, T: Field>(
f: &mut Flattener<T>,
p: typed_absy::TypedProgram<'ast, T>,
mut p: typed_absy::TypedProgram<'ast, T>,
) -> zir::ZirProgram<'ast, T> {
let main_module = p.modules.remove(&p.main).unwrap();
let main_function = main_module
.functions
.into_iter()
.find(|(key, _)| key.id == "main")
.unwrap()
.1;
let main_function = match main_function {
typed_absy::TypedFunctionSymbol::Here(f) => f,
_ => unreachable!(),
};
zir::ZirProgram {
modules: p
.modules
.into_iter()
.map(|(module_id, module)| (module_id, f.fold_module(module)))
.collect(),
main: p.main,
main: f.fold_function(main_function),
}
}

File diff suppressed because it is too large Load diff

View file

@ -4,69 +4,93 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod bounds_checker;
mod flat_propagation;
mod flatten_complex_types;
mod inline;
mod propagate_unroll;
mod propagation;
mod redefinition;
mod return_binder;
mod reducer;
mod uint_optimizer;
mod unconstrained_vars;
mod unroll;
mod variable_read_remover;
mod variable_write_remover;
use self::bounds_checker::BoundsChecker;
use self::flatten_complex_types::Flattener;
use self::inline::Inliner;
use self::propagate_unroll::PropagatedUnroller;
use self::propagation::Propagator;
use self::redefinition::RedefinitionOptimizer;
use self::return_binder::ReturnBinder;
use self::reducer::reduce_program;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_read_remover::VariableReadRemover;
use self::variable_write_remover::VariableWriteRemover;
use crate::flat_absy::FlatProg;
use crate::ir::Prog;
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{abi::Abi, TypedProgram};
use crate::zir::ZirProgram;
use std::fmt;
use zokrates_field::Field;
pub trait Analyse {
fn analyse(self) -> Self;
}
#[derive(Debug)]
pub enum Error {
Reducer(self::reducer::Error),
OutOfBounds(self::bounds_checker::Error),
Propagation(self::propagation::Error),
}
impl From<self::reducer::Error> for Error {
fn from(e: self::reducer::Error) -> Self {
Error::Reducer(e)
}
}
impl From<self::bounds_checker::Error> for Error {
fn from(e: bounds_checker::Error) -> Self {
Error::OutOfBounds(e)
}
}
impl From<self::propagation::Error> for Error {
fn from(e: propagation::Error) -> Self {
Error::Propagation(e)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Reducer(e) => write!(f, "{}", e),
Error::OutOfBounds(e) => write!(f, "{}", e),
Error::Propagation(e) => write!(f, "{}", e),
}
}
}
impl<'ast, T: Field> TypedProgram<'ast, T> {
pub fn analyse(self) -> ZirProgram<'ast, T> {
// propagated unrolling
let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!("{}", e));
pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> {
let r = reduce_program(self).map_err(Error::from)?;
// return binding
let r = ReturnBinder::bind(r);
// inline
let r = Inliner::inline(r);
let abi = r.abi();
// propagate
let r = Propagator::propagate(r);
let r = Propagator::propagate(r).map_err(Error::from)?;
// optimize redefinitions
let r = RedefinitionOptimizer::optimize(r);
// remove assignment to variable index
let r = VariableWriteRemover::apply(r);
// remove variable access to complex types
let r = VariableReadRemover::apply(r);
// check array accesses are in bounds
let r = BoundsChecker::check(r).map_err(Error::from)?;
// convert to zir, removing complex types
let zir = Flattener::flatten(r);
// optimize uint expressions
let zir = UintOptimizer::optimize(zir);
zir
Ok((zir, abi))
}
}
@ -78,7 +102,6 @@ impl<T: Field> Analyse for FlatProg<T> {
impl<T: Field> Analyse for Prog<T> {
fn analyse(self) -> Self {
let r = UnconstrainedVariableDetector::detect(self);
r
UnconstrainedVariableDetector::detect(self)
}
}

View file

@ -1,236 +0,0 @@
//! Module containing iterative unrolling in order to unroll nested loops with variable bounds
//!
//! For example:
//! ```zokrates
//! for field i in 0..5 do
//! for field j in i..5 do
//! //
//! endfor
//! endfor
//! ```
//!
//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop
//!
//! This module does exactly that:
//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant
//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet*
//! - unroll again, this time the 5 inner loops all have constant bounds
//!
//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does
//! not make progress anymore.
use crate::static_analysis::propagation::Propagator;
use crate::static_analysis::unroll::{Output, Unroller};
use crate::typed_absy::TypedProgram;
use zokrates_field::Field;
pub struct PropagatedUnroller;
impl PropagatedUnroller {
pub fn unroll<'ast, T: Field>(
p: TypedProgram<'ast, T>,
) -> Result<TypedProgram<'ast, T>, &'static str> {
let mut blocked_at = None;
// unroll a first time, retrieving whether the unroll is complete
let mut unrolled = Unroller::unroll(p);
loop {
// conditions to exit the loop
unrolled = match unrolled {
Output::Complete(p) => return Ok(p),
Output::Incomplete(next, index) => {
if Some(index) == blocked_at {
return Err("Loop unrolling failed. This happened because a loop bound is not constant");
} else {
// update the index where we blocked
blocked_at = Some(index);
// propagate
let propagated = Propagator::propagate_verbose(next);
// unroll
Unroller::unroll(propagated)
}
}
};
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::typed_absy::types::{FunctionKey, Signature};
use crate::typed_absy::*;
use zokrates_field::Bn128Field;
#[test]
fn detect_non_constant_bound() {
let loops = vec![TypedStatement::For(
Variable::field_element("i"),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![],
)];
let statements = loops;
let p = TypedProgram {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements,
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".into(),
};
assert!(PropagatedUnroller::unroll(p).is_err());
}
#[test]
fn for_loop() {
// for field i in 0..2
// for field j in i..2
// field foo = i + j
// should be unrolled to
// i_0 = 0
// j_0 = 0
// foo_0 = i_0 + j_0
// j_1 = 1
// foo_1 = i_0 + j_1
// i_1 = 1
// j_2 = 1
// foo_2 = i_1 + j_1
let s = TypedStatement::For(
Variable::field_element("i"),
FieldElementExpression::Number(Bn128Field::from(0)),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![TypedStatement::For(
Variable::field_element("j"),
FieldElementExpression::Identifier("i".into()),
FieldElementExpression::Number(Bn128Field::from(2)),
vec![
TypedStatement::Declaration(Variable::field_element("foo")),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element("foo")),
FieldElementExpression::Add(
box FieldElementExpression::Identifier("i".into()),
box FieldElementExpression::Identifier("j".into()),
)
.into(),
),
],
)],
);
let expected_statements = vec![
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(0),
)),
FieldElementExpression::Number(Bn128Field::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(0),
)),
FieldElementExpression::Number(Bn128Field::from(0)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(0),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(0)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(1),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(1),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(0)),
box FieldElementExpression::Identifier(Identifier::from("j").version(1)),
)
.into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("i").version(1),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("j").version(2),
)),
FieldElementExpression::Number(Bn128Field::from(1)).into(),
),
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::field_element(
Identifier::from("foo").version(2),
)),
FieldElementExpression::Add(
box FieldElementExpression::Identifier(Identifier::from("i").version(1)),
box FieldElementExpression::Identifier(Identifier::from("j").version(2)),
)
.into(),
),
];
let p = TypedProgram {
modules: vec![(
"main".into(),
TypedModule {
functions: vec![(
FunctionKey::with_id("main"),
TypedFunctionSymbol::Here(TypedFunction {
arguments: vec![],
signature: Signature::new(),
statements: vec![s],
}),
)]
.into_iter()
.collect(),
},
)]
.into_iter()
.collect(),
main: "main".into(),
};
let statements = match PropagatedUnroller::unroll(p).unwrap().modules
[std::path::Path::new("main")]
.functions[&FunctionKey::with_id("main")]
.clone()
{
TypedFunctionSymbol::Here(f) => f.statements,
_ => unreachable!(),
};
assert_eq!(statements, expected_statements);
}
}

File diff suppressed because it is too large Load diff

View file

@ -68,6 +68,6 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast> {
}
fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> {
self.identifiers.get(&s).map(|r| r.clone()).unwrap_or(s)
self.identifiers.get(&s).cloned().unwrap_or(s)
}
}

View file

@ -0,0 +1,235 @@
// The inlining phase takes a call site (fun::<gen>(args)) and inlines it:
// Given:
// ```
// def foo<n>(field a) -> field:
// a = a
// n = n
// return a
// ```
//
// The call site
// ```
// foo::<42>(x)
// ```
//
// Becomes
// ```
// # Call foo::<42> with a_0 := x
// n_0 = 42
// a_1 = a_0
// n_1 = n_0
// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1
// Notes:
// - The body of the function is in SSA form
// - The return value(s) are assigned to internal variables
use crate::embed::FlatEmbed;
use crate::static_analysis::reducer::Output;
use crate::static_analysis::reducer::ShallowTransformer;
use crate::static_analysis::reducer::Versions;
use crate::typed_absy::types::ConcreteGenericsAssignment;
use crate::typed_absy::CoreIdentifier;
use crate::typed_absy::Identifier;
use crate::typed_absy::TypedAssignee;
use crate::typed_absy::{
ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature,
Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression,
UExpressionInner, Variable,
};
use zokrates_field::Field;
pub enum InlineError<'ast, T> {
Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>),
Flat(
FlatEmbed,
Vec<u32>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
NonConstant(
DeclarationFunctionKey<'ast>,
Vec<Option<UExpression<'ast, T>>>,
Vec<TypedExpression<'ast, T>>,
Vec<Type<'ast, T>>,
),
}
fn get_canonical_function<'ast, T: Field>(
function_key: DeclarationFunctionKey<'ast>,
program: &TypedProgram<'ast, T>,
) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) {
match program
.modules
.get(&function_key.module)
.unwrap()
.functions
.iter()
.find(|(key, _)| function_key == **key)
.unwrap()
{
(_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program),
(key, s) => (key.clone(), s.clone()),
}
}
type InlineResult<'ast, T> = Result<
Output<(Vec<TypedStatement<'ast, T>>, Vec<TypedExpression<'ast, T>>), Vec<Versions<'ast>>>,
InlineError<'ast, T>,
>;
pub fn inline_call<'a, 'ast, T: Field>(
k: DeclarationFunctionKey<'ast>,
generics: Vec<Option<UExpression<'ast, T>>>,
arguments: Vec<TypedExpression<'ast, T>>,
output_types: Vec<Type<'ast, T>>,
program: &TypedProgram<'ast, T>,
versions: &'a mut Versions<'ast>,
) -> InlineResult<'ast, T> {
use std::convert::TryFrom;
use crate::typed_absy::Typed;
// we try to get concrete values for explicit generics
let generics_values: Vec<Option<u32>> = generics
.iter()
.map(|g| {
g.as_ref()
.map(|g| match g.as_inner() {
UExpressionInner::Value(v) => Ok(*v as u32),
_ => Err(()),
})
.transpose()
})
.collect::<Result<_, _>>()
.map_err(|_| {
InlineError::NonConstant(
k.clone(),
generics.clone(),
arguments.clone(),
output_types.clone(),
)
})?;
// we infer a signature based on inputs and outputs
// this is where we could handle explicit annotations
let inferred_signature = Signature::new()
.generics(generics.clone())
.inputs(arguments.iter().map(|a| a.get_type()).collect())
.outputs(output_types.clone());
// we try to get concrete values for the whole signature. if this fails we should propagate again
let inferred_signature = match ConcreteSignature::try_from(inferred_signature) {
Ok(s) => s,
Err(_) => {
return Err(InlineError::NonConstant(
k,
generics,
arguments,
output_types,
));
}
};
let (decl_key, symbol) = get_canonical_function(k.clone(), program);
// get an assignment of generics for this call site
let assignment: ConcreteGenericsAssignment<'ast> = k
.signature
.specialize(generics_values, &inferred_signature)
.map_err(|_| {
InlineError::Generic(
k.clone(),
ConcreteFunctionKey {
module: decl_key.module.clone(),
id: decl_key.id,
signature: inferred_signature.clone(),
},
)
})?;
let f = match symbol {
TypedFunctionSymbol::Here(f) => Ok(f),
TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat(
e,
e.generics(&assignment),
arguments.clone(),
output_types,
)),
_ => unreachable!(),
}?;
assert_eq!(f.arguments.len(), arguments.len());
let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) {
Output::Complete(v) => (v, None),
Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)),
};
let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone());
let input_bindings: Vec<TypedStatement<'ast, T>> = ssa_f
.arguments
.into_iter()
.zip(inferred_signature.inputs.clone())
.map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t))
.zip(arguments.clone())
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f
.statements
.into_iter()
.partition(|s| !matches!(s, TypedStatement::Return(..)));
assert_eq!(returns.len(), 1);
let returns = match returns.pop().unwrap() {
TypedStatement::Return(e) => e,
_ => unreachable!(),
};
let res: Vec<ConcreteVariable<'ast>> = inferred_signature
.outputs
.iter()
.enumerate()
.map(|(i, t)| {
ConcreteVariable::with_id_and_type(
Identifier::from(CoreIdentifier::Call(i)).version(
*versions
.entry(CoreIdentifier::Call(i).clone())
.and_modify(|e| *e += 1) // if it was already declared, we increment
.or_insert(0),
),
t.clone(),
)
})
.collect();
let expressions: Vec<TypedExpression<_>> = res
.iter()
.map(|v| TypedExpression::from(Variable::from(v.clone())))
.collect();
assert_eq!(res.len(), returns.len());
let output_bindings: Vec<TypedStatement<'ast, T>> = res
.into_iter()
.zip(returns)
.map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a))
.collect();
let pop_log = TypedStatement::PopCallLog;
let statements: Vec<_> = std::iter::once(call_log)
.chain(input_bindings)
.chain(statements)
.chain(output_bindings)
.chain(std::iter::once(pop_log))
.collect();
Ok(incomplete_data
.map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d))
.unwrap_or_else(|| Output::Complete((statements, expressions))))
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,58 +0,0 @@
use crate::typed_absy::folder::fold_statement;
use crate::typed_absy::identifier::CoreIdentifier;
use crate::typed_absy::*;
use zokrates_field::Field;
pub struct ReturnBinder;
impl ReturnBinder {
pub fn bind<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
ReturnBinder {}.fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
match s {
TypedStatement::Return(exprs) => {
let ret_identifiers: Vec<Identifier<'ast>> = (0..exprs.len())
.map(|i| CoreIdentifier::Internal("RETURN", i).into())
.collect();
let ret_expressions: Vec<TypedExpression<'ast, T>> = exprs
.iter()
.zip(ret_identifiers.iter())
.map(|(e, i)| match e.get_type() {
Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(),
Type::Boolean => BooleanExpression::Identifier(i.clone()).into(),
Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone())
.annotate(*array_type.ty, array_type.size)
.into(),
Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone())
.annotate(struct_type)
.into(),
Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone())
.annotate(bitwidth)
.into(),
})
.collect();
exprs
.into_iter()
.zip(ret_identifiers.iter())
.map(|(e, i)| {
TypedStatement::Definition(
TypedAssignee::Identifier(Variable::with_id_and_type(
i.clone(),
e.get_type(),
)),
e,
)
})
.chain(std::iter::once(TypedStatement::Return(ret_expressions)))
.collect()
}
s => fold_statement(self, s),
}
}
}

Some files were not shown because too many files have changed in this diff Show more