diff --git a/.circleci/config.yml b/.circleci/config.yml index 27c42c82..170c7c1a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -4,6 +4,7 @@ jobs: build: docker: - image: zokrates/env:latest + resource_class: large steps: - checkout - run: @@ -28,6 +29,7 @@ jobs: test: docker: - image: zokrates/env:latest + resource_class: large steps: - checkout - run: @@ -42,6 +44,9 @@ jobs: - run: name: Check format command: cargo fmt --all -- --check + - run: + name: Run clippy + command: cargo clippy - run: name: Build command: WITH_LIBSNARK=1 RUSTFLAGS="-D warnings" ./build.sh @@ -80,6 +85,7 @@ jobs: docker: - image: zokrates/env:latest - image: trufflesuite/ganache-cli:next + resource_class: large steps: - checkout - run: diff --git a/Cargo.lock b/Cargo.lock index e1733e12..1e7b1891 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,7 +133,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e8cb28c2137af1ef058aa59616db3f7df67dbb70bf2be4ee6920008cc30d98c" dependencies = [ "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -145,7 +145,7 @@ dependencies = [ "num-bigint 0.4.0", "num-traits 0.2.14", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -205,7 +205,7 @@ checksum = "5ac3d78c750b01f5df5b2e76d106ed31487a93b3868f14a7f0eb3a74f45e1d8a" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -653,7 +653,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e98e2ad1a782e33928b96fc3948e7c355e5af34ba4de7670fe8bac2a3b2006d" dependencies = [ "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -664,7 +664,7 @@ checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -765,7 +765,7 @@ checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", "synstructure", ] @@ -809,7 +809,7 @@ dependencies = [ "num-traits 0.2.14", "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -1059,9 +1059,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.49" +version = "0.3.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc15e39392125075f60c95ba416f5381ff6c3a948ff02ab12464715adf56c821" +checksum = "2d99f9e3e84b8f67f846ef5b4cbbc3b1c29f6c759fcbce6f01aa0e73d932a24c" dependencies = [ "wasm-bindgen", ] @@ -1074,9 +1074,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8916b1f6ca17130ec6568feccee27c156ad12037880833a3b842a823236502e7" +checksum = "56d855069fafbb9b344c0f962150cd2c1187975cb1c22c1522c240d8c4986714" [[package]] name = "libgit2-sys" @@ -1370,7 +1370,7 @@ dependencies = [ "pest_meta", "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -1773,7 +1773,7 @@ checksum = "b093b7a2bb58203b5da3056c05b4ec1fed827dcfdb37347a8841695263b3d06d" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", ] [[package]] @@ -1866,9 +1866,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.64" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fd9d1e9976102a03c542daa2eff1b43f9d72306342f3f8b3ed5fb8908195d6f" +checksum = "6498a9efc342871f91cc2d0d694c674368b4ceb40f62b65a7a08c3792935e702" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", @@ -1883,7 +1883,7 @@ checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", "unicode-xid 0.2.1", ] @@ -2106,9 +2106,9 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" [[package]] name = "wasm-bindgen" -version = "0.2.72" +version = "0.2.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fe8f61dba8e5d645a4d8132dc7a0a66861ed5e1045d2c0ed940fab33bac0fbe" +checksum = "83240549659d187488f91f33c0f8547cbfef0b2088bc470c116d1d260ef623d9" dependencies = [ "cfg-if 1.0.0", "wasm-bindgen-macro", @@ -2116,24 +2116,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.72" +version = "0.2.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046ceba58ff062da072c7cb4ba5b22a37f00a302483f7e2a6cdc18fedbdc1fd3" +checksum = "ae70622411ca953215ca6d06d3ebeb1e915f0f6613e3b495122878d7ebec7dae" dependencies = [ "bumpalo", "lazy_static", "log", "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73157efb9af26fb564bb59a009afd1c7c334a44db171d280690d0c3faaec3468" +checksum = "81b8b767af23de6ac18bf2168b690bed2902743ddf0fb39252e36f9e2bfc63ea" dependencies = [ "cfg-if 1.0.0", "js-sys", @@ -2143,9 +2143,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.72" +version = "0.2.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ef9aa01d36cda046f797c57959ff5f3c615c9cc63997a8d545831ec7976819b" +checksum = "3e734d91443f177bfdb41969de821e15c516931c3c3db3d318fa1b68975d0f6f" dependencies = [ "quote 1.0.9", "wasm-bindgen-macro-support", @@ -2153,28 +2153,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.72" +version = "0.2.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96eb45c1b2ee33545a813a92dbb53856418bf7eb54ab34f7f7ff1448a5b3735d" +checksum = "d53739ff08c8a68b0fdbcd54c372b8ab800b1449ab3c9d706503bc7dd1621b2c" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.72" +version = "0.2.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7148f4696fb4960a346eaa60bbfb42a1ac4ebba21f750f75fc1375b098d5ffa" +checksum = "d9a543ae66aa233d14bb765ed9af4a33e81b8b58d1584cf1b47ff8cd0b9e4489" [[package]] name = "wasm-bindgen-test" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f002ea97b5abdb19aafd48cbb5a0a7f6931cf36ea05a0a46ccc95d9f4c2cf43" +checksum = "e972e914de63aa53bd84865e54f5c761bd274d48e5be3a6329a662c0386aa67a" dependencies = [ "console_error_panic_hook", "js-sys", @@ -2186,9 +2186,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a6c0bd3933daf64c78fc25a7452530f79fa7e21f77fa03d608d1e988a66735" +checksum = "ea6153a8f9bf24588e9f25c87223414fff124049f68d3a442a0f0eab4768a8b6" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", @@ -2196,9 +2196,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.49" +version = "0.3.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59fe19d70f5dacc03f6e46777213facae5ac3801575d56ca6cbd4c93dcd12310" +checksum = "a905d57e488fec8861446d3393670fb50d27a262344013181c2cdf9fff5481be" dependencies = [ "js-sys", "wasm-bindgen", @@ -2252,7 +2252,7 @@ checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16" dependencies = [ "proc-macro2 1.0.24", "quote 1.0.9", - "syn 1.0.64", + "syn 1.0.67", "synstructure", ] diff --git a/asserts.zok b/asserts.zok new file mode 100644 index 00000000..2a2f1751 --- /dev/null +++ b/asserts.zok @@ -0,0 +1,7 @@ +def id() -> u32: + return N + +def main(): + assert(id::<5>() == 5) + assert(id::<6>() == 6) + return \ No newline at end of file diff --git a/changelogs/unreleased/695-schaeff b/changelogs/unreleased/695-schaeff new file mode 100644 index 00000000..539b5391 --- /dev/null +++ b/changelogs/unreleased/695-schaeff @@ -0,0 +1 @@ +Introduce constant generics for `u32` values. Introduce literal inference \ No newline at end of file diff --git a/changelogs/unreleased/754-schaeff b/changelogs/unreleased/754-schaeff new file mode 100644 index 00000000..ef7829ae --- /dev/null +++ b/changelogs/unreleased/754-schaeff @@ -0,0 +1 @@ +Make embed functions generic, enabling unpacking to any width at minimal cost \ No newline at end of file diff --git a/example.zok b/example.zok new file mode 100644 index 00000000..be818e32 --- /dev/null +++ b/example.zok @@ -0,0 +1,11 @@ +def foo(field[N] x) -> field[N]: + return x + +def bar(field[N] x) -> field[N]: + field[N] r = x + return r + +def main(field[3] x) -> field[2]: + field[2] z = foo(x)[0..2] + + return bar(z) \ No newline at end of file diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh new file mode 100755 index 00000000..dde98e91 --- /dev/null +++ b/scripts/benchmark.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Usage: benchmark.sh +# For MacOS: install gtime with homebrew `brew install gnu-time` + +cmd=$* +format="mem=%KK rss=%MK elapsed=%E cpu=%P cpu.sys=%S inputs=%I outputs=%O" + +if command -v gtime; then + gtime -f "$format" $cmd +else + /usr/bin/time -f "$format" $cmd +fi diff --git a/zokrates_abi/src/lib.rs b/zokrates_abi/src/lib.rs index 62b26a42..ef7f6c37 100644 --- a/zokrates_abi/src/lib.rs +++ b/zokrates_abi/src/lib.rs @@ -17,7 +17,7 @@ impl> Encode for Inputs { use std::collections::BTreeMap; use std::convert::TryFrom; use std::fmt; -use zokrates_core::typed_absy::{Type, UBitwidth}; +use zokrates_core::typed_absy::types::{ConcreteType, UBitwidth}; use zokrates_field::Field; @@ -94,18 +94,20 @@ impl fmt::Display for Value { } impl Value { - fn check(self, ty: Type) -> Result, String> { + fn check(self, ty: ConcreteType) -> Result, String> { match (self, ty) { - (Value::Field(f), Type::FieldElement) => Ok(CheckedValue::Field(f)), - (Value::U8(f), Type::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)), - (Value::U16(f), Type::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)), - (Value::U32(f), Type::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)), - (Value::Boolean(b), Type::Boolean) => Ok(CheckedValue::Boolean(b)), - (Value::Array(a), Type::Array(array_type)) => { - if a.len() != array_type.size { + (Value::Field(f), ConcreteType::FieldElement) => Ok(CheckedValue::Field(f)), + (Value::U8(f), ConcreteType::Uint(UBitwidth::B8)) => Ok(CheckedValue::U8(f)), + (Value::U16(f), ConcreteType::Uint(UBitwidth::B16)) => Ok(CheckedValue::U16(f)), + (Value::U32(f), ConcreteType::Uint(UBitwidth::B32)) => Ok(CheckedValue::U32(f)), + (Value::Boolean(b), ConcreteType::Boolean) => Ok(CheckedValue::Boolean(b)), + (Value::Array(a), ConcreteType::Array(array_type)) => { + let size = array_type.size; + + if a.len() != size as usize { Err(format!( "Expected array of size {}, found array of size {}", - array_type.size, + size, a.len() )) } else { @@ -116,15 +118,16 @@ impl Value { Ok(CheckedValue::Array(a)) } } - (Value::Struct(mut s), Type::Struct(members)) => { - if s.len() != members.len() { + (Value::Struct(mut s), ConcreteType::Struct(struc)) => { + if s.len() != struc.members_count() { Err(format!( "Expected {} member(s), found {}", - members.len(), + struc.members_count(), s.len() )) } else { - let s = members + let s = struc + .members .into_iter() .map(|member| { s.remove(&member.id) @@ -167,7 +170,7 @@ impl> Encode for CheckedValue { } impl Decode for CheckedValues { - type Expected = Vec; + type Expected = Vec; fn decode(raw: Vec, expected: Self::Expected) -> Self { CheckedValues( @@ -185,23 +188,24 @@ impl Decode for CheckedValues { } impl Decode for CheckedValue { - type Expected = Type; + type Expected = ConcreteType; fn decode(raw: Vec, expected: Self::Expected) -> Self { let mut raw = raw; match expected { - Type::FieldElement => CheckedValue::Field(raw.pop().unwrap()), - Type::Uint(UBitwidth::B8) => CheckedValue::U8( + ConcreteType::Int => unreachable!(), + ConcreteType::FieldElement => CheckedValue::Field(raw.pop().unwrap()), + ConcreteType::Uint(UBitwidth::B8) => CheckedValue::U8( u8::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), ), - Type::Uint(UBitwidth::B16) => CheckedValue::U16( + ConcreteType::Uint(UBitwidth::B16) => CheckedValue::U16( u16::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), ), - Type::Uint(UBitwidth::B32) => CheckedValue::U32( + ConcreteType::Uint(UBitwidth::B32) => CheckedValue::U32( u32::from_str_radix(&raw.pop().unwrap().to_dec_string(), 10).unwrap(), ), - Type::Boolean => { + ConcreteType::Boolean => { let v = raw.pop().unwrap(); CheckedValue::Boolean(if v == 0.into() { false @@ -211,12 +215,12 @@ impl Decode for CheckedValue { unreachable!() }) } - Type::Array(array_type) => CheckedValue::Array( + ConcreteType::Array(array_type) => CheckedValue::Array( raw.chunks(array_type.ty.get_primitive_count()) .map(|c| CheckedValue::decode(c.to_vec(), *array_type.ty.clone())) .collect(), ), - Type::Struct(members) => CheckedValue::Struct( + ConcreteType::Struct(members) => CheckedValue::Struct( members .into_iter() .scan(0, |state, member| { @@ -247,9 +251,9 @@ impl TryFrom for Values { match v { serde_json::Value::Array(a) => a .into_iter() - .map(|v| Value::try_from(v)) + .map(Value::try_from) .collect::>() - .map(|v| Values(v)), + .map(Values), v => Err(format!("Expected an array of values, found `{}`", v)), } } @@ -259,20 +263,22 @@ impl TryFrom for Value { type Error = String; fn try_from(v: serde_json::Value) -> Result, Self::Error> { match v { - serde_json::Value::String(s) => T::try_from_dec_str(&s) - .map(|v| Value::Field(v)) - .or_else(|_| match s.len() { - 4 => u8::from_str_radix(&s[2..], 16) - .map(|v| Value::U8(v)) - .map_err(|_| format!("Expected u8 value, found {}", s)), - 6 => u16::from_str_radix(&s[2..], 16) - .map(|v| Value::U16(v)) - .map_err(|_| format!("Expected u16 value, found {}", s)), - 10 => u32::from_str_radix(&s[2..], 16) - .map(|v| Value::U32(v)) - .map_err(|_| format!("Expected u32 value, found {}", s)), - _ => Err(format!("Cannot parse {} to any type", s)), - }), + serde_json::Value::String(s) => { + T::try_from_dec_str(&s) + .map(Value::Field) + .or_else(|_| match s.len() { + 4 => u8::from_str_radix(&s[2..], 16) + .map(Value::U8) + .map_err(|_| format!("Expected u8 value, found {}", s)), + 6 => u16::from_str_radix(&s[2..], 16) + .map(Value::U16) + .map_err(|_| format!("Expected u16 value, found {}", s)), + 10 => u32::from_str_radix(&s[2..], 16) + .map(Value::U32) + .map_err(|_| format!("Expected u32 value, found {}", s)), + _ => Err(format!("Cannot parse {} to any type", s)), + }) + } serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), serde_json::Value::Number(n) => Err(format!( "Value `{}` isn't allowed, did you mean `\"{}\"`?", @@ -280,14 +286,14 @@ impl TryFrom for Value { )), serde_json::Value::Array(a) => a .into_iter() - .map(|v| Value::try_from(v)) + .map(Value::try_from) .collect::>() - .map(|v| Value::Array(v)), + .map(Value::Array), serde_json::Value::Object(o) => o .into_iter() .map(|(k, v)| Value::try_from(v).map(|v| (k, v))) .collect::, _>>() - .map(|v| Value::Struct(v)), + .map(Value::Struct), v => Err(format!("Value `{}` isn't allowed", v)), } } @@ -320,10 +326,13 @@ impl Into for CheckedValues { fn parse(s: &str) -> Result, Error> { let json_values: serde_json::Value = serde_json::from_str(s).map_err(|e| Error::Json(e.to_string()))?; - Values::try_from(json_values).map_err(|e| Error::Conversion(e)) + Values::try_from(json_values).map_err(Error::Conversion) } -pub fn parse_strict(s: &str, types: Vec) -> Result, Error> { +pub fn parse_strict( + s: &str, + types: Vec, +) -> Result, Error> { let parsed = parse(s)?; if parsed.0.len() != types.len() { return Err(Error::Type(format!( @@ -338,7 +347,7 @@ pub fn parse_strict(s: &str, types: Vec) -> Result, _>>() - .map_err(|e| Error::Type(e))?; + .map_err(Error::Type)?; Ok(CheckedValues(checked)) } @@ -403,14 +412,19 @@ mod tests { mod strict { use super::*; - use zokrates_core::typed_absy::types::{StructMember, StructType}; + use zokrates_core::typed_absy::types::{ + ConcreteStructMember, ConcreteStructType, ConcreteType, + }; #[test] fn fields() { let s = r#"["1", "2"]"#; assert_eq!( - parse_strict::(s, vec![Type::FieldElement, Type::FieldElement]) - .unwrap(), + parse_strict::( + s, + vec![ConcreteType::FieldElement, ConcreteType::FieldElement] + ) + .unwrap(), CheckedValues(vec![ CheckedValue::Field(1.into()), CheckedValue::Field(2.into()) @@ -422,7 +436,8 @@ mod tests { fn bools() { let s = "[true, false]"; assert_eq!( - parse_strict::(s, vec![Type::Boolean, Type::Boolean]).unwrap(), + parse_strict::(s, vec![ConcreteType::Boolean, ConcreteType::Boolean]) + .unwrap(), CheckedValues(vec![ CheckedValue::Boolean(true), CheckedValue::Boolean(false) @@ -434,7 +449,11 @@ mod tests { fn array() { let s = "[[true, false]]"; assert_eq!( - parse_strict::(s, vec![Type::array(Type::Boolean, 2)]).unwrap(), + parse_strict::( + s, + vec![ConcreteType::array((ConcreteType::Boolean, 2usize))] + ) + .unwrap(), CheckedValues(vec![CheckedValue::Array(vec![ CheckedValue::Boolean(true), CheckedValue::Boolean(false) @@ -448,10 +467,13 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(StructType::new( + vec![ConcreteType::Struct(ConcreteStructType::new( "".into(), "".into(), - vec![StructMember::new("a".into(), Type::FieldElement)] + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] ))] ) .unwrap(), @@ -466,10 +488,13 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(StructType::new( + vec![ConcreteType::Struct(ConcreteStructType::new( "".into(), "".into(), - vec![StructMember::new("a".into(), Type::FieldElement)] + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] ))] ) .unwrap_err(), @@ -480,10 +505,13 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(StructType::new( + vec![ConcreteType::Struct(ConcreteStructType::new( "".into(), "".into(), - vec![StructMember::new("a".into(), Type::FieldElement)] + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] ))] ) .unwrap_err(), @@ -494,10 +522,13 @@ mod tests { assert_eq!( parse_strict::( s, - vec![Type::Struct(StructType::new( + vec![ConcreteType::Struct(ConcreteStructType::new( "".into(), "".into(), - vec![StructMember::new("a".into(), Type::FieldElement)] + vec![ConcreteStructMember::new( + "a".into(), + ConcreteType::FieldElement + )] ))] ) .unwrap_err(), diff --git a/zokrates_book/src/SUMMARY.md b/zokrates_book/src/SUMMARY.md index 2a6406f3..6ccefb31 100644 --- a/zokrates_book/src/SUMMARY.md +++ b/zokrates_book/src/SUMMARY.md @@ -12,6 +12,7 @@ - [Control flow](language/control_flow.md) - [Imports](language/imports.md) - [Comments](language/comments.md) + - [Generics](language/generics.md) - [Macros](language/macros.md) - [Toolbox](toolbox/index.md) diff --git a/zokrates_book/src/language/control_flow.md b/zokrates_book/src/language/control_flow.md index c83961bf..10036b2b 100644 --- a/zokrates_book/src/language/control_flow.md +++ b/zokrates_book/src/language/control_flow.md @@ -12,6 +12,12 @@ Arguments are passed by value. {{#include ../../../zokrates_cli/examples/book/side_effects.zok}} ``` +Generic paramaters, if any, must be compile-time constants. They are inferred by the compiler if that is possible, but can also be provided explicitly. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/generic_call.zok}} +``` + ### If-expressions An if-expression allows you to branch your code depending on a boolean condition. @@ -28,7 +34,7 @@ For loops are available with the following syntax: {{#include ../../../zokrates_cli/examples/book/for.zok}} ``` -The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. +The bounds have to be constant at compile-time, therefore they cannot depend on execution inputs. They can depend on generic parameters. ### Assertions diff --git a/zokrates_book/src/language/functions.md b/zokrates_book/src/language/functions.md index 4fe27339..8282b68e 100644 --- a/zokrates_book/src/language/functions.md +++ b/zokrates_book/src/language/functions.md @@ -7,7 +7,14 @@ A function has to be declared at the top level before it is called. ``` A function's signature has to be explicitly provided. -Functions can return many values by providing them as a comma-separated list. + +A function can be generic over any number of values of type `u32`. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/generic_function_declaration.zok}} +``` + +Functions can return multiple values by providing them as a comma-separated list. ```zokrates {{#include ../../../zokrates_cli/examples/book/multi_return.zok}} diff --git a/zokrates_book/src/language/generics.md b/zokrates_book/src/language/generics.md new file mode 100644 index 00000000..e0e0dfe1 --- /dev/null +++ b/zokrates_book/src/language/generics.md @@ -0,0 +1,7 @@ +## Generics + +ZoKrates supports code that is generic over constants of the `u32` type. No specific keyword is used: the compiler determines if the generic parameters are indeed constant at compile time. Here's an example of generic code in ZoKrates: + +```zokrates +{{#include ../../../zokrates_cli/examples/book/generics.zok}} +``` \ No newline at end of file diff --git a/zokrates_book/src/language/types.md b/zokrates_book/src/language/types.md index 4a978cae..78d12125 100644 --- a/zokrates_book/src/language/types.md +++ b/zokrates_book/src/language/types.md @@ -8,7 +8,7 @@ ZoKrates currently exposes two primitive types and two complex types: This is the most basic type in ZoKrates, and it represents a field element with positive integer values in `[0, p - 1]` where `p` is a (large) prime number. Standard arithmetic operations are supported; note that [division in the finite field](https://en.wikipedia.org/wiki/Finite_field_arithmetic) behaves differently than in the case of integers. -As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/reference/proving_schemes.html#alt_bn128) curve supported by Ethereum. +As an example, `p` is set to `21888242871839275222246405745257275088548364400416034343698204186575808495617` when working with the [ALT_BN128](/toolbox/proving_schemes.html#alt_bn128) curve supported by Ethereum. While `field` values mostly behave like unsigned integers, one should keep in mind that they overflow at `p` and not some power of 2, so that we have: @@ -32,13 +32,23 @@ Similarly to booleans, unsigned integer inputs of the main function only accept The division operation calculates the standard floor division for integers. The `%` operand can be used to obtain the remainder. +### Numeric inference + +In the case of decimal literals like `42`, the compiler tries to find the appropriate type (`field`, `u8`, `u16` or `u32`) depending on the context. If it cannot converge to a single option, an error is returned. This means that there is no default type for decimal literals. + +All operations between literals have the semantics of the infered type. + +```zokrates +{{#include ../../../zokrates_cli/examples/book/numeric_inference.zok}} +``` + ## Complex Types ZoKrates provides two complex types: arrays and structs. ### Arrays -ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. +ZoKrates supports static arrays, i.e., whose length needs to be known at compile time. For more details on generic array sizes, see [constant generics](/language/generics.html) Arrays can contain elements of any type and have arbitrary dimensions. The following example code shows examples of how to use arrays: diff --git a/zokrates_cli/build.rs b/zokrates_cli/build.rs index add3c623..bbdc6b48 100644 --- a/zokrates_cli/build.rs +++ b/zokrates_cli/build.rs @@ -11,5 +11,5 @@ fn export_stdlib() { let out_dir = env::var("OUT_DIR").unwrap(); let mut options = CopyOptions::new(); options.overwrite = true; - copy_items(&vec!["tests/contract"], out_dir, &options).unwrap(); + copy_items(&["tests/contract"], out_dir, &options).unwrap(); } diff --git a/zokrates_cli/examples/arrays/array_loop.zok b/zokrates_cli/examples/arrays/array_loop.zok index 024d3777..73e4fb45 100644 --- a/zokrates_cli/examples/arrays/array_loop.zok +++ b/zokrates_cli/examples/arrays/array_loop.zok @@ -1,7 +1,7 @@ -def main() -> field: - field[3] a = [1, 2, 3] - field c = 0 - for field i in 0..3 do +def main() -> u32: + u32[3] a = [1, 2, 3] + u32 c = 0 + for u32 i in 0..3 do c = c + a[i] endfor return c \ No newline at end of file diff --git a/zokrates_cli/examples/arrays/array_loop_update.zok b/zokrates_cli/examples/arrays/array_loop_update.zok index 815de1c0..d16cc9da 100644 --- a/zokrates_cli/examples/arrays/array_loop_update.zok +++ b/zokrates_cli/examples/arrays/array_loop_update.zok @@ -1,7 +1,7 @@ -def main() -> (field[3]): - field[3] a = [1, 2, 3] - field[3] c = [4, 5, 6] - for field i in 0..3 do +def main() -> (u32[3]): + u32[3] a = [1, 2, 3] + u32[3] c = [4, 5, 6] + for u32 i in 0..3 do c[i] = c[i] + a[i] endfor return c \ No newline at end of file diff --git a/zokrates_cli/examples/arrays/boolean_array.zok b/zokrates_cli/examples/arrays/boolean_array.zok index 08633b67..89576f4d 100644 --- a/zokrates_cli/examples/arrays/boolean_array.zok +++ b/zokrates_cli/examples/arrays/boolean_array.zok @@ -3,7 +3,7 @@ def main(bool[3] a) -> (field[3]): a[1] = true || a[2] a[2] = a[0] field[3] result = [0; 3] - for field i in 0..3 do + for u32 i in 0..3 do result[i] = if a[i] then 33 else 0 fi endfor return result diff --git a/zokrates_cli/examples/arrays/cube.zok b/zokrates_cli/examples/arrays/cube.zok index 1ea112d0..eca27ec6 100644 --- a/zokrates_cli/examples/arrays/cube.zok +++ b/zokrates_cli/examples/arrays/cube.zok @@ -1,9 +1,9 @@ def main(field[2][2][2] cube) -> field: field res = 0 - for field i in 0..2 do - for field j in 0..2 do - for field k in 0..2 do + for u32 i in 0..2 do + for u32 j in 0..2 do + for u32 k in 0..2 do res = res + cube[i][j][k] endfor endfor diff --git a/zokrates_cli/examples/arrays/lookup.zok b/zokrates_cli/examples/arrays/lookup.zok index 95404692..6b25d439 100644 --- a/zokrates_cli/examples/arrays/lookup.zok +++ b/zokrates_cli/examples/arrays/lookup.zok @@ -1,2 +1,2 @@ -def main(field index, field[5] array) -> field: +def main(u32 index, field[5] array) -> field: return array[index] \ No newline at end of file diff --git a/zokrates_cli/examples/arrays/multidim_update.zok b/zokrates_cli/examples/arrays/multidim_update.zok index 89bdb22d..358c049a 100644 --- a/zokrates_cli/examples/arrays/multidim_update.zok +++ b/zokrates_cli/examples/arrays/multidim_update.zok @@ -1,4 +1,4 @@ -def main(field[10][10][10] a, field i, field j, field k) -> (field[3]): +def main(field[10][10][10] a, u32 i, u32 j, u32 k) -> (field[3]): a[i][j][k] = 42 field[3][3] b = [[1, 2, 3], [1, 2, 3], [1, 2, 3]] return b[0] \ No newline at end of file diff --git a/zokrates_cli/examples/arrays/repeat.zok b/zokrates_cli/examples/arrays/repeat.zok new file mode 100644 index 00000000..07ca853e --- /dev/null +++ b/zokrates_cli/examples/arrays/repeat.zok @@ -0,0 +1,4 @@ +def main(field a) -> field[4]: + u32 SIZE = 4 + field[SIZE] res = [a; SIZE] + return res \ No newline at end of file diff --git a/zokrates_cli/examples/arrays/slicefrom.zok b/zokrates_cli/examples/arrays/slicefrom.zok index a3e09ec7..225d4a6f 100644 --- a/zokrates_cli/examples/arrays/slicefrom.zok +++ b/zokrates_cli/examples/arrays/slicefrom.zok @@ -1,6 +1,6 @@ -def slice32from(field offset, field[2048] input) -> (field[32]): +def slice32from(u32 offset, field[2048] input) -> (field[32]): field[32] result = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - for field i in 0..32 do + for u32 i in 0..32 do result[i] = input[offset + i] endfor return result diff --git a/zokrates_cli/examples/arrays/wrap_select.zok b/zokrates_cli/examples/arrays/wrap_select.zok index 162b6e1f..03c8b545 100644 --- a/zokrates_cli/examples/arrays/wrap_select.zok +++ b/zokrates_cli/examples/arrays/wrap_select.zok @@ -1,4 +1,4 @@ -def get(field[32] array, field index) -> field: +def get(field[32] array, u32 index) -> field: return array[index] def main() -> field: diff --git a/zokrates_cli/examples/book/array.zok b/zokrates_cli/examples/book/array.zok index 943fb12a..1e8bcfce 100644 --- a/zokrates_cli/examples/book/array.zok +++ b/zokrates_cli/examples/book/array.zok @@ -5,4 +5,6 @@ def main() -> field: field[4] c = [...a, 4] // initialize an array copying values from `a`, followed by 4 field[2] d = a[1..3] // initialize an array copying a slice from `a` bool[3] e = [true, true || false, true] // initialize a boolean array + u32 SIZE = 3 + field[SIZE] f = [1, 2, 3] // initialize a field array with a size that's a compile-time constant return a[0] + b[1] + c[2] diff --git a/zokrates_cli/examples/book/assert.zok b/zokrates_cli/examples/book/assert.zok index c4277fed..75db2763 100644 --- a/zokrates_cli/examples/book/assert.zok +++ b/zokrates_cli/examples/book/assert.zok @@ -1,3 +1,3 @@ def main() -> (): - assert(1 == 2) + assert(1f == 2f) return \ No newline at end of file diff --git a/zokrates_cli/examples/book/for.zok b/zokrates_cli/examples/book/for.zok index 8c61fc29..b5d8f502 100644 --- a/zokrates_cli/examples/book/for.zok +++ b/zokrates_cli/examples/book/for.zok @@ -1,7 +1,7 @@ -def main() -> field: - field res = 0 - for field i in 0..4 do - for field j in i..5 do +def main() -> u32: + u32 res = 0 + for u32 i in 0..4 do + for u32 j in i..5 do res = res + i endfor endfor diff --git a/zokrates_cli/examples/book/for_scope.zok b/zokrates_cli/examples/book/for_scope.zok index 340ca6c9..2b92dd1f 100644 --- a/zokrates_cli/examples/book/for_scope.zok +++ b/zokrates_cli/examples/book/for_scope.zok @@ -1,6 +1,6 @@ -def main() -> field: - field a = 0 - for field i in 0..5 do +def main() -> u32: + u32 a = 0 + for u32 i in 0..5 do a = a + i endfor // return i <- not allowed diff --git a/zokrates_cli/examples/book/function_declaration.zok b/zokrates_cli/examples/book/function_declaration.zok index 89658ac0..2e0f1bb4 100644 --- a/zokrates_cli/examples/book/function_declaration.zok +++ b/zokrates_cli/examples/book/function_declaration.zok @@ -1,5 +1,5 @@ -def foo() -> field: - return 1 +def foo(field a, field b) -> field: + return a + b def main() -> field: - return foo() \ No newline at end of file + return foo(1, 2) \ No newline at end of file diff --git a/zokrates_cli/examples/book/generic_call.zok b/zokrates_cli/examples/book/generic_call.zok new file mode 100644 index 00000000..a26a9f81 --- /dev/null +++ b/zokrates_cli/examples/book/generic_call.zok @@ -0,0 +1,7 @@ +def foo() -> field[P]: + return [42; P] + +def main() -> field[2]: + // `P` is inferred from the declaration of `res`, while `N` is provided explicitly + field[2] res = foo::<3, _>() + return res \ No newline at end of file diff --git a/zokrates_cli/examples/book/generic_function_declaration.zok b/zokrates_cli/examples/book/generic_function_declaration.zok new file mode 100644 index 00000000..07c77353 --- /dev/null +++ b/zokrates_cli/examples/book/generic_function_declaration.zok @@ -0,0 +1,6 @@ +def foo() -> field[N]: + return [42; N] + +def main() -> field[2]: + field[2] res = foo() + return res \ No newline at end of file diff --git a/zokrates_cli/examples/book/generics.zok b/zokrates_cli/examples/book/generics.zok new file mode 100644 index 00000000..a9f11321 --- /dev/null +++ b/zokrates_cli/examples/book/generics.zok @@ -0,0 +1,9 @@ +def sum(field[N] a) -> field: + field res = 0 + for u32 i in 0..N do + res = res + a[i] + endfor + return res + +def main(field[3] a) -> field: + return sum(a) \ No newline at end of file diff --git a/zokrates_cli/examples/book/no_shadowing.zok b/zokrates_cli/examples/book/no_shadowing.zok index 2708ee0b..be221372 100644 --- a/zokrates_cli/examples/book/no_shadowing.zok +++ b/zokrates_cli/examples/book/no_shadowing.zok @@ -1,7 +1,7 @@ def main() -> field: field a = 2 // field a = 3 <- not allowed - for field i in 0..5 do + for u32 i in 0..5 do // field a = 7 <- not allowed endfor return a \ No newline at end of file diff --git a/zokrates_cli/examples/book/numeric_inference.zok b/zokrates_cli/examples/book/numeric_inference.zok new file mode 100644 index 00000000..068229e0 --- /dev/null +++ b/zokrates_cli/examples/book/numeric_inference.zok @@ -0,0 +1,9 @@ +def main(): + // `255` is infered to `255f`, and the addition happens between field elements + assert(255 + 1f == 256) + + // `255` is infered to `255u8`, and the addition happens between u8 + // This causes an overflow + assert(255 + 1u8 == 0) + + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/embed_type_mismatch.zok b/zokrates_cli/examples/compile_errors/embed_type_mismatch.zok new file mode 100644 index 00000000..047df566 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/embed_type_mismatch.zok @@ -0,0 +1,5 @@ +import "EMBED/u8_to_bits" as u8_to_bits + +def main(u8 x): + bool[32] b = u8_to_bits(x) // note the incorrect array length on the left + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/assign_size_mismatch.zok b/zokrates_cli/examples/compile_errors/generics/assign_size_mismatch.zok new file mode 100644 index 00000000..727346f6 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/assign_size_mismatch.zok @@ -0,0 +1,8 @@ +def foo(field[N] a) -> bool: + field[3] b = a + return true + + +def main(field[1] a): + assert(foo(a)) + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/concrete_length_mismatch.zok b/zokrates_cli/examples/compile_errors/generics/concrete_length_mismatch.zok new file mode 100644 index 00000000..906326a4 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/concrete_length_mismatch.zok @@ -0,0 +1,3 @@ +def main(): + assert([1] == [1, 2]) + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/generics_in_main.zok b/zokrates_cli/examples/compile_errors/generics/generics_in_main.zok new file mode 100644 index 00000000..137f0b61 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/generics_in_main.zok @@ -0,0 +1,2 @@ +def main

(field[P] a): + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/incompatible.zok b/zokrates_cli/examples/compile_errors/generics/incompatible.zok new file mode 100644 index 00000000..bbf114d7 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/incompatible.zok @@ -0,0 +1,5 @@ +def foo

(field[P] a, field[P] b) -> field: + return 42 + +def main() -> field: + return foo([1, 2], [1]) \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/no_weak_eq.zok b/zokrates_cli/examples/compile_errors/generics/no_weak_eq.zok new file mode 100644 index 00000000..5169eb5b --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/no_weak_eq.zok @@ -0,0 +1,3 @@ +def main(): + assert([[1]] == [1, 2]) + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/generics/unused.zok b/zokrates_cli/examples/compile_errors/generics/unused.zok new file mode 100644 index 00000000..025c7452 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/generics/unused.zok @@ -0,0 +1,2 @@ +def main

(): + return \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/out_of_for_scope.zok b/zokrates_cli/examples/compile_errors/out_of_for_scope.zok index 5d701a18..e1c73da1 100644 --- a/zokrates_cli/examples/compile_errors/out_of_for_scope.zok +++ b/zokrates_cli/examples/compile_errors/out_of_for_scope.zok @@ -1,4 +1,4 @@ def main() -> field: - for field i in 0..5 do + for u32 i in 0..5 do endfor return i \ No newline at end of file diff --git a/zokrates_cli/examples/conditions.zok b/zokrates_cli/examples/conditions.zok index fd92d11d..74cd9b9a 100644 --- a/zokrates_cli/examples/conditions.zok +++ b/zokrates_cli/examples/conditions.zok @@ -14,6 +14,6 @@ def main(field a) -> field: assert(2 * b == a * 12 + 60) field c = 7 * (b + a) assert(isEqual(c, 7 * b + 7 * a)) - field k = if [1, 2] == [3, 4] then 1 else 3 fi + field k = if [1f, 2] == [3f, 4] then 1 else 3 fi assert([Bar { foo : [Foo { a: 42 }]}] == [Bar { foo : [Foo { a: 42 }]}]) return b + c \ No newline at end of file diff --git a/zokrates_cli/examples/for.zok b/zokrates_cli/examples/for.zok index 6cce5135..7bdf2b8f 100644 --- a/zokrates_cli/examples/for.zok +++ b/zokrates_cli/examples/for.zok @@ -1,6 +1,10 @@ +def bound(field x) -> u32: + return 41 + 1 + def main(field a) -> field: field x = 7 - for field i in 0..10 do + x = x + 1 + for u32 i in 0..bound(x) do // x = x + a x = x + a endfor diff --git a/zokrates_cli/examples/functions/lt_comparison.zok b/zokrates_cli/examples/functions/lt_comparison.zok index 8e9b3782..c6542840 100644 --- a/zokrates_cli/examples/functions/lt_comparison.zok +++ b/zokrates_cli/examples/functions/lt_comparison.zok @@ -4,14 +4,14 @@ def lt(field a,field b) -> bool: def cutoff() -> field: return 31337 -def getThing(field index) -> field: +def getThing(u32 index) -> field: field[6] a = [13, 23, 43, 53, 73, 83] return a[index] def cubeThing(field thing) -> field: return thing**3 -def main(field index) -> bool: +def main(u32 index) -> bool: field thing = getThing(index) thing = cubeThing(thing) return lt(cutoff(), thing) diff --git a/zokrates_cli/examples/merkleTree/pedersenPathProof3.zok b/zokrates_cli/examples/merkleTree/pedersenPathProof3.zok index d6225098..9a7114be 100644 --- a/zokrates_cli/examples/merkleTree/pedersenPathProof3.zok +++ b/zokrates_cli/examples/merkleTree/pedersenPathProof3.zok @@ -4,9 +4,6 @@ import "ecc/babyjubjubParams" as context from "ecc/babyjubjubParams" import BabyJubJubParams import "hashes/utils/256bitsDirectionHelper" as multiplex -def multiplex(bool selector, u32[8] left, u32[8] right) -> (u32[8]): - return if selector then right else left fi - // Merke-Tree inclusion proof for tree depth 3 using SNARK efficient pedersen hashes // directionSelector=> true if current digest is on the rhs of the hash diff --git a/zokrates_cli/examples/n_choose_k.zok b/zokrates_cli/examples/n_choose_k.zok index a7a34d94..b5e820d7 100644 --- a/zokrates_cli/examples/n_choose_k.zok +++ b/zokrates_cli/examples/n_choose_k.zok @@ -1,9 +1,11 @@ +import "utils/casts/u32_to_field" as to_field + // Binomial Coeffizient, n!/(k!*(n-k)!). def fac(field x) -> field: field f = 1 field counter = 0 - for field i in 1..100 do - f = if counter == x then f else f * i fi + for u32 i in 1..100 do + f = if counter == x then f else f * to_field(i) fi counter = if counter == x then counter else counter + 1 fi endfor return f diff --git a/zokrates_cli/examples/propagate.zok b/zokrates_cli/examples/propagate.zok index 9b3176b0..543ac20c 100644 --- a/zokrates_cli/examples/propagate.zok +++ b/zokrates_cli/examples/propagate.zok @@ -2,7 +2,7 @@ def main() -> field: field a = 1 + 2 + 3 field b = if 1 < a then 3 else a + 3 fi field c = if b + a == 2 then 1 else b fi - for field e in 0..2 do + for u32 e in 0..2 do field g = 4 c = c + g endfor diff --git a/zokrates_cli/examples/reduceable_exponent.zok b/zokrates_cli/examples/reduceable_exponent.zok index d3c38a74..b17bb1c0 100644 --- a/zokrates_cli/examples/reduceable_exponent.zok +++ b/zokrates_cli/examples/reduceable_exponent.zok @@ -1,3 +1,3 @@ def main() -> field: - field a = 2 - return 2**(a**2 + 2) \ No newline at end of file + u32 a = 2 + return 2**(a * 2 + 2) \ No newline at end of file diff --git a/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok b/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok index e61db31d..a5cd37d6 100644 --- a/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/prime_sudoku_checker.zok @@ -32,26 +32,26 @@ def main(field a21, field b11, field b22, field c11, field c22, field d21, priva bool res = true // go through the whole grid and check that all elements are valid - for field i in 0..4 do - for field j in 0..4 do + for u32 i in 0..4 do + for u32 j in 0..4 do res = res && validateInput(a[i][j]) endfor endfor // go through the 4 2x2 boxes and check that they do not contain duplicates - for field i in 0..1 do - for field j in 0..1 do + for u32 i in 0..1 do + for u32 j in 0..1 do res = res && checkNoDuplicates(a[2*i][2*i], a[2*i][2*i + 1], a[2*i + 1][2*i], a[2*i + 1][2*i + 1]) endfor endfor // go through the 4 rows and check that they do not contain duplicates - for field i in 0..4 do + for u32 i in 0..4 do res = res && checkNoDuplicates(a[i][0], a[i][1], a[i][2], a[i][3]) endfor // go through the 4 columns and check that they do not contain duplicates - for field j in 0..4 do + for u32 j in 0..4 do res = res && checkNoDuplicates(a[0][j], a[1][j], a[2][j], a[3][j]) endfor diff --git a/zokrates_cli/examples/waldo.zok b/zokrates_cli/examples/waldo.zok index 1839354f..436a42b3 100644 --- a/zokrates_cli/examples/waldo.zok +++ b/zokrates_cli/examples/waldo.zok @@ -10,6 +10,6 @@ def isWaldo(field a, field p, field q) -> bool: return a == p * q // define all -def main(field[3] a, private field index, private field p, private field q) -> bool: +def main(field[3] a, private u32 index, private field p, private field q) -> bool: // prover provides the index of Waldo return isWaldo(a[index], p, q) \ No newline at end of file diff --git a/zokrates_cli/src/bin.rs b/zokrates_cli/src/bin.rs index 1c0c9095..68ad5687 100644 --- a/zokrates_cli/src/bin.rs +++ b/zokrates_cli/src/bin.rs @@ -32,10 +32,13 @@ fn cli() -> Result<(), String> { compile::subcommand(), check::subcommand(), compute_witness::subcommand(), + #[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] setup::subcommand(), export_verifier::subcommand(), + #[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] generate_proof::subcommand(), print_proof::subcommand(), + #[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] verify::subcommand()]) .get_matches(); @@ -140,7 +143,7 @@ mod tests { let interpreter = ir::Interpreter::default(); let _ = interpreter - .execute(&artifacts.prog(), &vec![Bn128Field::from(0)]) + .execute(&artifacts.prog(), &[Bn128Field::from(0)]) .unwrap(); } } @@ -169,7 +172,7 @@ mod tests { let interpreter = ir::Interpreter::default(); - let res = interpreter.execute(&artifacts.prog(), &vec![Bn128Field::from(0)]); + let res = interpreter.execute(&artifacts.prog(), &[Bn128Field::from(0)]); assert!(res.is_err()); } diff --git a/zokrates_cli/src/constants.rs b/zokrates_cli/src/constants.rs index 1af664ab..2720f85f 100644 --- a/zokrates_cli/src/constants.rs +++ b/zokrates_cli/src/constants.rs @@ -19,6 +19,7 @@ lazy_static! { .unwrap(); } +#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") { if cfg!(feature = "ark") { if cfg!(feature = "bellman") { @@ -26,27 +27,21 @@ pub const BACKENDS: &[&str] = if cfg!(feature = "libsnark") { } else { &[LIBSNARK, ARK] } + } else if cfg!(feature = "bellman") { + &[BELLMAN, LIBSNARK] } else { - if cfg!(feature = "bellman") { - &[BELLMAN, LIBSNARK] - } else { - &[LIBSNARK] - } + &[LIBSNARK] } +} else if cfg!(feature = "ark") { + if cfg!(feature = "bellman") { + &[BELLMAN, ARK] + } else { + &[ARK] + } +} else if cfg!(feature = "bellman") { + &[BELLMAN] } else { - if cfg!(feature = "ark") { - if cfg!(feature = "bellman") { - &[BELLMAN, ARK] - } else { - &[ARK] - } - } else { - if cfg!(feature = "bellman") { - &[BELLMAN] - } else { - &[] - } - } + &[] }; pub const BN128: &str = "bn128"; diff --git a/zokrates_cli/src/ops/check.rs b/zokrates_cli/src/ops/check.rs index 99186b86..f74c6a7c 100644 --- a/zokrates_cli/src/ops/check.rs +++ b/zokrates_cli/src/ops/check.rs @@ -69,7 +69,7 @@ fn cli_check(sub_matches: &ArgMatches) -> Result<(), String> { format!( "{}:{}", file.strip_prefix(std::env::current_dir().unwrap()) - .unwrap_or(file.as_path()) + .unwrap_or_else(|_| file.as_path()) .display(), e.value() ) diff --git a/zokrates_cli/src/ops/compile.rs b/zokrates_cli/src/ops/compile.rs index aeea574d..a66e468f 100644 --- a/zokrates_cli/src/ops/compile.rs +++ b/zokrates_cli/src/ops/compile.rs @@ -92,9 +92,9 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { let fmt_error = |e: &CompileError| { let file = e.file().canonicalize().unwrap(); format!( - "{}:{}", + "{}: {}", file.strip_prefix(std::env::current_dir().unwrap()) - .unwrap_or(file.as_path()) + .unwrap_or_else(|_| file.as_path()) .display(), e.value() ) @@ -154,7 +154,7 @@ fn cli_compile(sub_matches: &ArgMatches) -> Result<(), String> { .map_err(|why| format!("Couldn't create {}: {}", hr_output_path.display(), why))?; let mut hrofb = BufWriter::new(hr_output_file); - write!(&mut hrofb, "{}\n", program_flattened) + writeln!(&mut hrofb, "{}", program_flattened) .map_err(|_| "Unable to write data to file".to_string())?; hrofb .flush() diff --git a/zokrates_cli/src/ops/compute_witness.rs b/zokrates_cli/src/ops/compute_witness.rs index f1657f2d..1162e5bc 100644 --- a/zokrates_cli/src/ops/compute_witness.rs +++ b/zokrates_cli/src/ops/compute_witness.rs @@ -8,7 +8,7 @@ use zokrates_abi::Encode; use zokrates_core::ir; use zokrates_core::ir::ProgEnum; use zokrates_core::typed_absy::abi::Abi; -use zokrates_core::typed_absy::{Signature, Type}; +use zokrates_core::typed_absy::types::{ConcreteSignature, ConcreteType}; use zokrates_field::Field; pub fn subcommand() -> App<'static, 'static> { @@ -106,9 +106,12 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu abi.signature() } - false => Signature::new() - .inputs(vec![Type::FieldElement; ir_prog.main.arguments.len()]) - .outputs(vec![Type::FieldElement; ir_prog.main.returns.len()]), + false => ConcreteSignature::new() + .inputs(vec![ + ConcreteType::FieldElement; + ir_prog.main.arguments.len() + ]) + .outputs(vec![ConcreteType::FieldElement; ir_prog.main.returns.len()]), }; use zokrates_abi::Inputs; @@ -123,8 +126,8 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu a.map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string())) .collect::, _>>() }) - .unwrap_or(Ok(vec![])) - .map(|v| Inputs::Raw(v)) + .unwrap_or_else(|| Ok(vec![])) + .map(Inputs::Raw) } // take stdin arguments true => { @@ -137,7 +140,7 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu use zokrates_abi::parse_strict; parse_strict(&input, signature.inputs) - .map(|parsed| Inputs::Abi(parsed)) + .map(Inputs::Abi) .map_err(|why| why.to_string()) } Err(_) => Err(String::from("???")), @@ -148,10 +151,10 @@ fn cli_compute(ir_prog: ir::Prog, sub_matches: &ArgMatches) -> Resu Ok(_) => { input.retain(|x| x != '\n'); input - .split(" ") + .split(' ') .map(|x| T::try_from_dec_str(x).map_err(|_| x.to_string())) .collect::, _>>() - .map(|v| Inputs::Raw(v)) + .map(Inputs::Raw) } Err(_) => Err(String::from("???")), }, diff --git a/zokrates_cli/src/ops/generate_proof.rs b/zokrates_cli/src/ops/generate_proof.rs index 3118a7d0..7156518c 100644 --- a/zokrates_cli/src/ops/generate_proof.rs +++ b/zokrates_cli/src/ops/generate_proof.rs @@ -174,7 +174,7 @@ fn cli_generate_proof, B: Backend>( .write(proof.as_bytes()) .map_err(|why| format!("Couldn't write to {}: {}", proof_path.display(), why))?; - println!("Proof:\n{}", format!("{}", proof)); + println!("Proof:\n{}", proof); Ok(()) } diff --git a/zokrates_cli/src/ops/mod.rs b/zokrates_cli/src/ops/mod.rs index bb1748c8..22c53a38 100644 --- a/zokrates_cli/src/ops/mod.rs +++ b/zokrates_cli/src/ops/mod.rs @@ -1,7 +1,6 @@ pub mod check; pub mod compile; pub mod compute_witness; -#[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] pub mod export_verifier; #[cfg(any(feature = "bellman", feature = "ark", feature = "libsnark"))] pub mod generate_proof; diff --git a/zokrates_cli/tests/code/n_choose_k.zok b/zokrates_cli/tests/code/n_choose_k.zok index a7a34d94..b5e820d7 100644 --- a/zokrates_cli/tests/code/n_choose_k.zok +++ b/zokrates_cli/tests/code/n_choose_k.zok @@ -1,9 +1,11 @@ +import "utils/casts/u32_to_field" as to_field + // Binomial Coeffizient, n!/(k!*(n-k)!). def fac(field x) -> field: field f = 1 field counter = 0 - for field i in 1..100 do - f = if counter == x then f else f * i fi + for u32 i in 1..100 do + f = if counter == x then f else f * to_field(i) fi counter = if counter == x then counter else counter + 1 fi endfor return f diff --git a/zokrates_cli/tests/integration.rs b/zokrates_cli/tests/integration.rs index b147d3da..0a666cf6 100644 --- a/zokrates_cli/tests/integration.rs +++ b/zokrates_cli/tests/integration.rs @@ -3,7 +3,7 @@ extern crate serde_json; #[cfg(test)] mod integration { - use assert_cli; + use serde_json::from_reader; use std::fs; use std::fs::File; @@ -147,11 +147,11 @@ mod integration { .map_err(|why| why.to_string()) .unwrap(); - let signature = abi.signature().clone(); + let signature = abi.signature(); let inputs_abi: zokrates_abi::Inputs = parse_strict(&json_input_str, signature.inputs) - .map(|parsed| zokrates_abi::Inputs::Abi(parsed)) + .map(zokrates_abi::Inputs::Abi) .map_err(|why| why.to_string()) .unwrap(); let inputs_raw: Vec<_> = inputs_abi @@ -169,7 +169,7 @@ mod integration { inline_witness_path.to_str().unwrap(), ]; - if inputs_raw.len() > 0 { + if !inputs_raw.is_empty() { compute_inline.push("-a"); for arg in &inputs_raw { @@ -202,7 +202,7 @@ mod integration { assert_eq!(inline_witness, witness); - for line in expected_witness.as_str().split("\n") { + for line in expected_witness.as_str().split('\n') { assert!( witness.contains(line), "Witness generation failed for {}\n\nLine \"{}\" not found in witness", diff --git a/zokrates_core/src/absy/from_ast.rs b/zokrates_core/src/absy/from_ast.rs index 4d19159d..1ec78a8e 100644 --- a/zokrates_core/src/absy/from_ast.rs +++ b/zokrates_core/src/absy/from_ast.rs @@ -1,7 +1,6 @@ use crate::absy; use crate::imports; -use num::ToPrimitive; use num_bigint::BigUint; use zokrates_pest_ast as pest; @@ -10,14 +9,14 @@ impl<'ast> From> for absy::Module<'ast> { absy::Module::with_symbols( prog.structs .into_iter() - .map(|t| absy::SymbolDeclarationNode::from(t)) + .map(absy::SymbolDeclarationNode::from) .chain( prog.functions .into_iter() - .map(|f| absy::SymbolDeclarationNode::from(f)), + .map(absy::SymbolDeclarationNode::from), ), ) - .imports(prog.imports.into_iter().map(|i| absy::ImportNode::from(i))) + .imports(prog.imports.into_iter().map(absy::ImportNode::from)) } } @@ -31,17 +30,16 @@ impl<'ast> From> for absy::ImportNode<'ast> { .alias(import.alias.map(|a| a.span.as_str())) .span(import.span) } - pest::ImportDirective::From(import) => imports::Import::new( - Some(import.symbol.span.as_str()), - std::path::Path::new(import.source.span.as_str()), - ) - .alias( - import - .alias - .map(|a| a.span.as_str()) - .or(Some(import.symbol.span.as_str())), - ) - .span(import.span), + pest::ImportDirective::From(import) => { + let symbol_str = import.symbol.span.as_str(); + + imports::Import::new( + Some(import.symbol.span.as_str()), + std::path::Path::new(import.source.span.as_str()), + ) + .alias(import.alias.map(|a| a.span.as_str()).or(Some(symbol_str))) + .span(import.span) + } } } } @@ -58,7 +56,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'a fields: definition .fields .into_iter() - .map(|f| absy::StructDefinitionFieldNode::from(f)) + .map(absy::StructDefinitionFieldNode::from) .collect(), } .span(span.clone()); @@ -72,7 +70,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'a } impl<'ast> From> for absy::StructDefinitionFieldNode<'ast> { - fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode { + fn from(field: pest::StructField<'ast>) -> absy::StructDefinitionFieldNode<'ast> { use crate::absy::NodeValue; let span = field.span; @@ -92,6 +90,13 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { let span = function.span; let signature = absy::UnresolvedSignature::new() + .generics( + function + .generics + .into_iter() + .map(absy::ConstantGenericNode::from) + .collect(), + ) .inputs( function .parameters @@ -105,7 +110,7 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { .returns .clone() .into_iter() - .map(|r| absy::UnresolvedTypeNode::from(r)) + .map(absy::UnresolvedTypeNode::from) .collect(), ); @@ -115,12 +120,12 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { arguments: function .parameters .into_iter() - .map(|a| absy::ParameterNode::from(a)) + .map(absy::ParameterNode::from) .collect(), statements: function .statements .into_iter() - .flat_map(|s| statements_from_statement(s)) + .flat_map(statements_from_statement) .collect(), signature, } @@ -134,6 +139,16 @@ impl<'ast> From> for absy::SymbolDeclarationNode<'ast> { } } +impl<'ast> From> for absy::ConstantGenericNode<'ast> { + fn from(g: pest::IdentifierExpression<'ast>) -> absy::ConstantGenericNode<'ast> { + use absy::NodeValue; + + let name = g.span.as_str(); + + name.span(g.span) + } +} + impl<'ast> From> for absy::ParameterNode<'ast> { fn from(param: pest::Parameter<'ast>) -> absy::ParameterNode<'ast> { use crate::absy::NodeValue; @@ -247,7 +262,7 @@ impl<'ast> From> for absy::StatementNode<'ast> { expressions: statement .expressions .into_iter() - .map(|e| absy::ExpressionNode::from(e)) + .map(absy::ExpressionNode::from) .collect(), } .span(statement.span.clone()), @@ -275,7 +290,7 @@ impl<'ast> From> for absy::StatementNode<'ast> { let statements: Vec> = statement .statements .into_iter() - .flat_map(|s| statements_from_statement(s)) + .flat_map(statements_from_statement) .collect(); let var = absy::Variable::new(index, ty).span(statement.index.span); @@ -289,7 +304,7 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { match expression { pest::Expression::Binary(e) => absy::ExpressionNode::from(e), pest::Expression::Ternary(e) => absy::ExpressionNode::from(e), - pest::Expression::Constant(e) => absy::ExpressionNode::from(e), + pest::Expression::Literal(e) => absy::ExpressionNode::from(e), pest::Expression::Identifier(e) => absy::ExpressionNode::from(e), pest::Expression::Postfix(e) => absy::ExpressionNode::from(e), pest::Expression::InlineArray(e) => absy::ExpressionNode::from(e), @@ -458,7 +473,7 @@ impl<'ast> From> for absy::ExpressionNode<'ast array .expressions .into_iter() - .map(|e| absy::SpreadOrExpression::from(e)) + .map(absy::SpreadOrExpression::from) .collect(), ) .span(array.span) @@ -489,13 +504,8 @@ impl<'ast> From> for absy::ExpressionNode use crate::absy::NodeValue; let value = absy::ExpressionNode::from(*initializer.value); - let count: absy::ExpressionNode<'ast> = absy::ExpressionNode::from(initializer.count); - let count = match count.value { - absy::Expression::FieldConstant(v) => v.to_usize().unwrap(), - _ => unreachable!(), - }; - absy::Expression::InlineArray(vec![absy::SpreadOrExpression::Expression(value); count]) - .span(initializer.span) + let count = absy::ExpressionNode::from(*initializer.count); + absy::Expression::ArrayInitializer(box value, box count).span(initializer.span) } } @@ -529,9 +539,25 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { pest::Access::Call(a) => match acc.value { absy::Expression::Identifier(_) => absy::Expression::FunctionCall( &id_str, - a.expressions + a.explicit_generics.map(|explicit_generics| { + explicit_generics + .values + .into_iter() + .map(|i| match i { + pest::ConstantGenericValue::Underscore(_) => None, + pest::ConstantGenericValue::Value(v) => { + Some(absy::ExpressionNode::from(v)) + } + pest::ConstantGenericValue::Identifier(i) => { + Some(absy::Expression::Identifier(i.span.as_str()).span(i.span)) + } + }) + .collect() + }), + a.arguments + .expressions .into_iter() - .map(|e| absy::ExpressionNode::from(e)) + .map(absy::ExpressionNode::from) .collect(), ), e => unimplemented!("only identifiers are callable, found \"{}\"", e), @@ -548,29 +574,63 @@ impl<'ast> From> for absy::ExpressionNode<'ast> { } } -impl<'ast> From> for absy::ExpressionNode<'ast> { - fn from(expression: pest::ConstantExpression<'ast>) -> absy::ExpressionNode<'ast> { +impl<'ast> From> for absy::ExpressionNode<'ast> { + fn from(expression: pest::DecimalLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> { use crate::absy::NodeValue; + + match expression.suffix { + Some(suffix) => match suffix { + pest::DecimalSuffix::Field(_) => absy::Expression::FieldConstant( + BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(), + ), + pest::DecimalSuffix::U32(_) => absy::Expression::U32Constant( + u32::from_str_radix(&expression.value.span.as_str(), 10).unwrap(), + ), + pest::DecimalSuffix::U16(_) => absy::Expression::U16Constant( + u16::from_str_radix(&expression.value.span.as_str(), 10).unwrap(), + ), + pest::DecimalSuffix::U8(_) => absy::Expression::U8Constant( + u8::from_str_radix(&expression.value.span.as_str(), 10).unwrap(), + ), + } + .span(expression.span), + None => absy::Expression::IntConstant( + BigUint::parse_bytes(&expression.value.span.as_str().as_bytes(), 10).unwrap(), + ) + .span(expression.span), + } + } +} + +impl<'ast> From> for absy::ExpressionNode<'ast> { + fn from(expression: pest::HexLiteralExpression<'ast>) -> absy::ExpressionNode<'ast> { + use crate::absy::NodeValue; + + match expression.value { + pest::HexNumberExpression::U32(e) => { + absy::Expression::U32Constant(u32::from_str_radix(&e.span.as_str(), 16).unwrap()) + } + pest::HexNumberExpression::U16(e) => { + absy::Expression::U16Constant(u16::from_str_radix(&e.span.as_str(), 16).unwrap()) + } + pest::HexNumberExpression::U8(e) => { + absy::Expression::U8Constant(u8::from_str_radix(&e.span.as_str(), 16).unwrap()) + } + } + .span(expression.span) + } +} + +impl<'ast> From> for absy::ExpressionNode<'ast> { + fn from(expression: pest::LiteralExpression<'ast>) -> absy::ExpressionNode<'ast> { + use crate::absy::NodeValue; + match expression { - pest::ConstantExpression::BooleanLiteral(c) => { + pest::LiteralExpression::BooleanLiteral(c) => { absy::Expression::BooleanConstant(c.value.parse().unwrap()).span(c.span) } - pest::ConstantExpression::DecimalNumber(n) => absy::Expression::FieldConstant( - BigUint::parse_bytes(&n.value.as_bytes(), 10).unwrap(), - ) - .span(n.span), - pest::ConstantExpression::U8(n) => absy::Expression::U8Constant( - u8::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(), - ) - .span(n.span), - pest::ConstantExpression::U16(n) => absy::Expression::U16Constant( - u16::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(), - ) - .span(n.span), - pest::ConstantExpression::U32(n) => absy::Expression::U32Constant( - u32::from_str_radix(&n.value.trim_start_matches("0x"), 16).unwrap(), - ) - .span(n.span), + pest::LiteralExpression::DecimalLiteral(n) => absy::ExpressionNode::from(n), + pest::LiteralExpression::HexLiteral(n) => absy::ExpressionNode::from(n), } } } @@ -611,8 +671,8 @@ impl<'ast> From> for absy::AssigneeNode<'ast> { } } -impl<'ast> From> for absy::UnresolvedTypeNode { - fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode { +impl<'ast> From> for absy::UnresolvedTypeNode<'ast> { + fn from(t: pest::Type<'ast>) -> absy::UnresolvedTypeNode<'ast> { use crate::absy::types::UnresolvedType; use crate::absy::NodeValue; @@ -642,21 +702,7 @@ impl<'ast> From> for absy::UnresolvedTypeNode { t.dimensions .into_iter() - .map(|s| match s { - pest::Expression::Constant(c) => match c { - pest::ConstantExpression::DecimalNumber(n) => { - str::parse::(&n.value).unwrap() - } - _ => unimplemented!( - "Array size should be a decimal number, found {}", - c.span().as_str() - ), - }, - e => unimplemented!( - "Array size should be constant, found {}", - e.span().as_str() - ), - }) + .map(absy::ExpressionNode::from) .rev() .fold(None, |acc, s| match acc { None => Some(UnresolvedType::array(inner_type.clone(), s)), @@ -690,10 +736,9 @@ mod tests { arguments: vec![], statements: vec![absy::Statement::Return( absy::ExpressionList { - expressions: vec![absy::Expression::FieldConstant(BigUint::from( - 42u32, - )) - .into()], + expressions: vec![ + absy::Expression::IntConstant(42usize.into()).into() + ], } .into(), ) @@ -771,10 +816,9 @@ mod tests { ], statements: vec![absy::Statement::Return( absy::ExpressionList { - expressions: vec![absy::Expression::FieldConstant(BigUint::from( - 42u32, - )) - .into()], + expressions: vec![ + absy::Expression::IntConstant(42usize.into()).into() + ], } .into(), ) @@ -800,7 +844,7 @@ mod tests { use super::*; /// Helper method to generate the ast for `def main(private {ty} a): return` which we use to check ty - fn wrap(ty: UnresolvedType) -> absy::Module<'static> { + fn wrap(ty: UnresolvedType<'static>) -> absy::Module<'static> { absy::Module { symbols: vec![absy::SymbolDeclaration { id: "main", @@ -834,21 +878,31 @@ mod tests { ("bool", UnresolvedType::Boolean), ( "field[2]", - UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 2), - ), - ( - "field[2][3]", - UnresolvedType::Array( - box UnresolvedType::Array(box UnresolvedType::FieldElement.mock(), 3) - .mock(), - 2, + absy::UnresolvedType::Array( + box absy::UnresolvedType::FieldElement.mock(), + absy::Expression::IntConstant(2usize.into()).mock(), ), ), ( - "bool[2][3]", - UnresolvedType::Array( - box UnresolvedType::Array(box UnresolvedType::Boolean.mock(), 3).mock(), - 2, + "field[2][3]", + absy::UnresolvedType::Array( + box absy::UnresolvedType::Array( + box absy::UnresolvedType::FieldElement.mock(), + absy::Expression::IntConstant(3usize.into()).mock(), + ) + .mock(), + absy::Expression::IntConstant(2usize.into()).mock(), + ), + ), + ( + "bool[2][3u32]", + absy::UnresolvedType::Array( + box absy::UnresolvedType::Array( + box absy::UnresolvedType::Boolean.mock(), + absy::Expression::U32Constant(3u32).mock(), + ) + .mock(), + absy::Expression::IntConstant(2usize.into()).mock(), ), ), ]; @@ -893,13 +947,13 @@ mod tests { // we basically accept `()?[]*` : an optional call at first, then only array accesses let vectors = vec![ - ("a", absy::Expression::Identifier("a").into()), + ("a", absy::Expression::Identifier("a")), ( "a[3]", absy::Expression::Select( box absy::Expression::Identifier("a").into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(3u32)).into(), + absy::Expression::IntConstant(3usize.into()).into(), ) .into(), ), @@ -910,13 +964,13 @@ mod tests { box absy::Expression::Select( box absy::Expression::Identifier("a").into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(3u32)).into(), + absy::Expression::IntConstant(3usize.into()).into(), ) .into(), ) .into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(4u32)).into(), + absy::Expression::IntConstant(4usize.into()).into(), ) .into(), ), @@ -926,11 +980,12 @@ mod tests { absy::Expression::Select( box absy::Expression::FunctionCall( "a", - vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()], + None, + vec![absy::Expression::IntConstant(3usize.into()).into()], ) .into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(4u32)).into(), + absy::Expression::IntConstant(4usize.into()).into(), ) .into(), ), @@ -941,17 +996,18 @@ mod tests { box absy::Expression::Select( box absy::Expression::FunctionCall( "a", - vec![absy::Expression::FieldConstant(BigUint::from(3u32)).into()], + None, + vec![absy::Expression::IntConstant(3usize.into()).into()], ) .into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(4u32)).into(), + absy::Expression::IntConstant(4usize.into()).into(), ) .into(), ) .into(), box absy::RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(5u32)).into(), + absy::Expression::IntConstant(5usize.into()).into(), ) .into(), ), @@ -993,7 +1049,7 @@ mod tests { // For different definitions, we generate declarations // Case 1: `id = expr` where `expr` is not a function call // This is a simple assignment, doesn't implicitely declare a variable - // A `Definition` is generatedm and no `Declaration`s + // A `Definition` is generated and no `Declaration`s let definition = pest::DefinitionStatement { lhs: vec![pest::OptionallyTypedAssignee { @@ -1008,9 +1064,12 @@ mod tests { }, span: span.clone(), }], - expression: pest::Expression::Constant(pest::ConstantExpression::DecimalNumber( - pest::DecimalNumberExpression { - value: String::from("42"), + expression: pest::Expression::Literal(pest::LiteralExpression::DecimalLiteral( + pest::DecimalLiteralExpression { + value: pest::DecimalNumber { + span: Span::new(&"1", 0, 1).unwrap(), + }, + suffix: None, span: span.clone(), }, )), @@ -1049,7 +1108,11 @@ mod tests { span: span.clone(), }, accesses: vec![pest::Access::Call(pest::CallAccess { - expressions: vec![], + explicit_generics: None, + arguments: pest::Arguments { + expressions: vec![], + span: span.clone(), + }, span: span.clone(), })], span: span.clone(), @@ -1106,7 +1169,11 @@ mod tests { span: span.clone(), }, accesses: vec![pest::Access::Call(pest::CallAccess { - expressions: vec![], + explicit_generics: None, + arguments: pest::Arguments { + expressions: vec![], + span: span.clone(), + }, span: span.clone(), })], span: span.clone(), diff --git a/zokrates_core/src/absy/mod.rs b/zokrates_core/src/absy/mod.rs index 71d8703f..68c64eb4 100644 --- a/zokrates_core/src/absy/mod.rs +++ b/zokrates_core/src/absy/mod.rs @@ -105,7 +105,7 @@ impl<'ast> Module<'ast> { } } -pub type UnresolvedTypeNode = Node; +pub type UnresolvedTypeNode<'ast> = Node>; /// A struct type definition #[derive(Debug, Clone, PartialEq)] @@ -133,7 +133,7 @@ pub type StructDefinitionNode<'ast> = Node>; #[derive(Debug, Clone, PartialEq)] pub struct StructDefinitionField<'ast> { pub id: Identifier<'ast>, - pub ty: UnresolvedTypeNode, + pub ty: UnresolvedTypeNode<'ast>, } impl<'ast> fmt::Display for StructDefinitionField<'ast> { @@ -216,6 +216,8 @@ impl<'ast> fmt::Debug for Module<'ast> { } } +pub type ConstantGenericNode<'ast> = Node>; + /// A function defined locally #[derive(Clone, PartialEq)] pub struct Function<'ast> { @@ -224,13 +226,26 @@ pub struct Function<'ast> { /// Vector of statements that are executed when running the function pub statements: Vec>, /// function signature - pub signature: UnresolvedSignature, + pub signature: UnresolvedSignature<'ast>, } pub type FunctionNode<'ast> = Node>; impl<'ast> fmt::Display for Function<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if !self.signature.generics.is_empty() { + write!( + f, + "<{}>", + self.signature + .generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!( f, "({}):\n{}", @@ -294,6 +309,7 @@ impl<'ast> fmt::Display for Assignee<'ast> { } /// A statement in a `Function` +#[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq)] pub enum Statement<'ast> { Return(ExpressionListNode<'ast>), @@ -319,9 +335,9 @@ impl<'ast> fmt::Display for Statement<'ast> { Statement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), Statement::Assertion(ref e) => write!(f, "assert({})", e), Statement::For(ref var, ref start, ref stop, ref list) => { - write!(f, "for {} in {}..{} do\n", var, start, stop)?; + writeln!(f, "for {} in {}..{} do", var, start, stop)?; for l in list { - write!(f, "\t\t{}\n", l)?; + writeln!(f, "\t\t{}", l)?; } write!(f, "\tendfor") } @@ -348,9 +364,9 @@ impl<'ast> fmt::Debug for Statement<'ast> { } Statement::Assertion(ref e) => write!(f, "Assertion({:?})", e), Statement::For(ref var, ref start, ref stop, ref list) => { - write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?; + writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?; for l in list { - write!(f, "\t\t{:?}\n", l)?; + writeln!(f, "\t\t{:?}", l)?; } write!(f, "\tendfor") } @@ -454,11 +470,11 @@ impl<'ast> fmt::Display for Range<'ast> { self.from .as_ref() .map(|e| e.to_string()) - .unwrap_or("".to_string()), + .unwrap_or_else(|| "".to_string()), self.to .as_ref() .map(|e| e.to_string()) - .unwrap_or("".to_string()) + .unwrap_or_else(|| "".to_string()) ) } } @@ -472,6 +488,7 @@ impl<'ast> fmt::Debug for Range<'ast> { /// An expression #[derive(Clone, PartialEq)] pub enum Expression<'ast> { + IntConstant(BigUint), FieldConstant(BigUint), BooleanConstant(bool), U8Constant(u8), @@ -491,7 +508,11 @@ pub enum Expression<'ast> { Box>, Box>, ), - FunctionCall(FunctionIdentifier<'ast>, Vec>), + FunctionCall( + FunctionIdentifier<'ast>, + Option>>>, + Vec>, + ), Lt(Box>, Box>), Le(Box>, Box>), Eq(Box>, Box>), @@ -500,6 +521,7 @@ pub enum Expression<'ast> { And(Box>, Box>), Not(Box>), InlineArray(Vec>), + ArrayInitializer(Box>, Box>), InlineStruct(UserTypeId, Vec<(Identifier<'ast>, ExpressionNode<'ast>)>), Select(Box>, Box>), Member(Box>, Box>), @@ -520,6 +542,7 @@ impl<'ast> fmt::Display for Expression<'ast> { Expression::U8Constant(ref i) => write!(f, "{}", i), Expression::U16Constant(ref i) => write!(f, "{}", i), Expression::U32Constant(ref i) => write!(f, "{}", i), + Expression::IntConstant(ref i) => write!(f, "{}", i), Expression::Identifier(ref var) => write!(f, "{}", var), Expression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), Expression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), @@ -535,8 +558,21 @@ impl<'ast> fmt::Display for Expression<'ast> { "if {} then {} else {} fi", condition, consequent, alternative ), - Expression::FunctionCall(ref i, ref p) => { - write!(f, "{}(", i,)?; + Expression::FunctionCall(ref i, ref g, ref p) => { + if let Some(g) = g { + write!( + f, + "::<{}>", + g.iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| "_".into())) + .collect::>() + .join(", "), + )?; + } + write!(f, "{}(", i)?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -562,6 +598,7 @@ impl<'ast> fmt::Display for Expression<'ast> { } write!(f, "]") } + Expression::ArrayInitializer(ref e, ref count) => write!(f, "[{}; {}]", e, count), Expression::InlineStruct(ref id, ref members) => { write!(f, "{} {{", id)?; for (i, (member_id, e)) in members.iter().enumerate() { @@ -587,10 +624,11 @@ impl<'ast> fmt::Display for Expression<'ast> { impl<'ast> fmt::Debug for Expression<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Expression::U8Constant(ref i) => write!(f, "{:x}", i), - Expression::U16Constant(ref i) => write!(f, "{:x}", i), - Expression::U32Constant(ref i) => write!(f, "{:x}", i), - Expression::FieldConstant(ref i) => write!(f, "Num({:?})", i), + Expression::U8Constant(ref i) => write!(f, "U8({:x})", i), + Expression::U16Constant(ref i) => write!(f, "U16({:x})", i), + Expression::U32Constant(ref i) => write!(f, "U32({:x})", i), + Expression::FieldConstant(ref i) => write!(f, "Field({:?})", i), + Expression::IntConstant(ref i) => write!(f, "Int({:?})", i), Expression::Identifier(ref var) => write!(f, "Ide({})", var), Expression::Add(ref lhs, ref rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs), Expression::Sub(ref lhs, ref rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs), @@ -606,8 +644,8 @@ impl<'ast> fmt::Debug for Expression<'ast> { "IfElse({:?}, {:?}, {:?})", condition, consequent, alternative ), - Expression::FunctionCall(ref i, ref p) => { - write!(f, "FunctionCall({:?}, (", i)?; + Expression::FunctionCall(ref g, ref i, ref p) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } @@ -623,6 +661,9 @@ impl<'ast> fmt::Debug for Expression<'ast> { f.debug_list().entries(exprs.iter()).finish()?; write!(f, "]") } + Expression::ArrayInitializer(ref e, ref count) => { + write!(f, "ArrayInitializer({:?}, {:?})", e, count) + } Expression::InlineStruct(ref id, ref members) => { write!(f, "InlineStruct({:?}, [", id)?; f.debug_list().entries(members.iter()).finish()?; @@ -645,21 +686,13 @@ impl<'ast> fmt::Debug for Expression<'ast> { } /// A list of expressions, used in return statements -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Default)] pub struct ExpressionList<'ast> { pub expressions: Vec>, } pub type ExpressionListNode<'ast> = Node>; -impl<'ast> ExpressionList<'ast> { - pub fn new() -> ExpressionList<'ast> { - ExpressionList { - expressions: vec![], - } - } -} - impl<'ast> fmt::Display for ExpressionList<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, param) in self.expressions.iter().enumerate() { diff --git a/zokrates_core/src/absy/node.rs b/zokrates_core/src/absy/node.rs index cf8d3fb4..304a5a0e 100644 --- a/zokrates_core/src/absy/node.rs +++ b/zokrates_core/src/absy/node.rs @@ -81,7 +81,7 @@ impl<'ast> NodeValue for ExpressionList<'ast> {} impl<'ast> NodeValue for Assignee<'ast> {} impl<'ast> NodeValue for Statement<'ast> {} impl<'ast> NodeValue for SymbolDeclaration<'ast> {} -impl NodeValue for UnresolvedType {} +impl<'ast> NodeValue for UnresolvedType<'ast> {} impl<'ast> NodeValue for StructDefinition<'ast> {} impl<'ast> NodeValue for StructDefinitionField<'ast> {} impl<'ast> NodeValue for Function<'ast> {} @@ -92,6 +92,7 @@ impl<'ast> NodeValue for Parameter<'ast> {} impl<'ast> NodeValue for Import<'ast> {} impl<'ast> NodeValue for Spread<'ast> {} impl<'ast> NodeValue for Range<'ast> {} +impl<'ast> NodeValue for Identifier<'ast> {} impl PartialEq for Node { fn eq(&self, other: &Node) -> bool { diff --git a/zokrates_core/src/absy/types.rs b/zokrates_core/src/absy/types.rs index 49658fa2..439f9ca7 100644 --- a/zokrates_core/src/absy/types.rs +++ b/zokrates_core/src/absy/types.rs @@ -1,3 +1,4 @@ +use crate::absy::ExpressionNode; use crate::absy::UnresolvedTypeNode; use std::fmt; @@ -8,15 +9,15 @@ pub type MemberId = String; pub type UserTypeId = String; #[derive(Clone, PartialEq, Debug)] -pub enum UnresolvedType { +pub enum UnresolvedType<'ast> { FieldElement, Boolean, Uint(usize), - Array(Box, usize), + Array(Box>, ExpressionNode<'ast>), User(UserTypeId), } -impl fmt::Display for UnresolvedType { +impl<'ast> fmt::Display for UnresolvedType<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { UnresolvedType::FieldElement => write!(f, "field"), @@ -28,8 +29,8 @@ impl fmt::Display for UnresolvedType { } } -impl UnresolvedType { - pub fn array(ty: UnresolvedTypeNode, size: usize) -> Self { +impl<'ast> UnresolvedType<'ast> { + pub fn array(ty: UnresolvedTypeNode<'ast>, size: ExpressionNode<'ast>) -> Self { UnresolvedType::Array(box ty, size) } } @@ -39,17 +40,19 @@ pub type FunctionIdentifier<'ast> = &'ast str; pub use self::signature::UnresolvedSignature; mod signature { + use crate::absy::ConstantGenericNode; use std::fmt; use crate::absy::UnresolvedTypeNode; - #[derive(Clone, PartialEq)] - pub struct UnresolvedSignature { - pub inputs: Vec, - pub outputs: Vec, + #[derive(Clone, PartialEq, Default)] + pub struct UnresolvedSignature<'ast> { + pub generics: Vec>, + pub inputs: Vec>, + pub outputs: Vec>, } - impl fmt::Debug for UnresolvedSignature { + impl<'ast> fmt::Debug for UnresolvedSignature<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -59,7 +62,7 @@ mod signature { } } - impl fmt::Display for UnresolvedSignature { + impl<'ast> fmt::Display for UnresolvedSignature<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { @@ -79,20 +82,22 @@ mod signature { } } - impl UnresolvedSignature { - pub fn new() -> UnresolvedSignature { - UnresolvedSignature { - inputs: vec![], - outputs: vec![], - } + impl<'ast> UnresolvedSignature<'ast> { + pub fn new() -> UnresolvedSignature<'ast> { + UnresolvedSignature::default() } - pub fn inputs(mut self, inputs: Vec) -> Self { + pub fn generics(mut self, generics: Vec>) -> Self { + self.generics = generics; + self + } + + pub fn inputs(mut self, inputs: Vec>) -> Self { self.inputs = inputs; self } - pub fn outputs(mut self, outputs: Vec) -> Self { + pub fn outputs(mut self, outputs: Vec>) -> Self { self.outputs = outputs; self } diff --git a/zokrates_core/src/absy/variable.rs b/zokrates_core/src/absy/variable.rs index f03b3f0d..6b8f9b60 100644 --- a/zokrates_core/src/absy/variable.rs +++ b/zokrates_core/src/absy/variable.rs @@ -7,21 +7,21 @@ use crate::absy::Identifier; #[derive(Clone, PartialEq)] pub struct Variable<'ast> { pub id: Identifier<'ast>, - pub _type: UnresolvedTypeNode, + pub _type: UnresolvedTypeNode<'ast>, } pub type VariableNode<'ast> = Node>; impl<'ast> Variable<'ast> { - pub fn new>(id: S, t: UnresolvedTypeNode) -> Variable<'ast> { + pub fn new>(id: S, t: UnresolvedTypeNode<'ast>) -> Variable<'ast> { Variable { id: id.into(), _type: t, } } - pub fn get_type(&self) -> UnresolvedType { - self._type.value.clone() + pub fn get_type(&self) -> &UnresolvedType<'ast> { + &self._type.value } } diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index 7711a52a..1e32927e 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -9,6 +9,7 @@ use crate::imports::{self, Importer}; use crate::ir; use crate::macros; use crate::semantics::{self, Checker}; +use crate::static_analysis; use crate::static_analysis::Analyse; use crate::typed_absy::abi::Abi; use crate::zir::ZirProgram; @@ -55,6 +56,7 @@ pub enum CompileErrorInner { MacroError(macros::Error), SemanticError(semantics::ErrorInner), ReadError(io::Error), + AnalysisError(static_analysis::Error), } impl CompileErrorInner { @@ -129,6 +131,12 @@ impl From for CompileError { } } +impl From for CompileErrorInner { + fn from(error: static_analysis::Error) -> Self { + CompileErrorInner::AnalysisError(error) + } +} + impl fmt::Display for CompileErrorInner { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -137,6 +145,7 @@ impl fmt::Display for CompileErrorInner { CompileErrorInner::SemanticError(ref e) => write!(f, "{}", e), CompileErrorInner::ReadError(ref e) => write!(f, "{}", e), CompileErrorInner::ImportError(ref e) => write!(f, "{}", e), + CompileErrorInner::AnalysisError(ref e) => write!(f, "{}", e), } } } @@ -179,7 +188,7 @@ pub fn compile>( }) } -pub fn check<'ast, T: Field, E: Into>( +pub fn check>( source: String, location: FilePath, resolver: Option<&dyn Resolver>, @@ -196,19 +205,18 @@ fn check_with_arena<'ast, T: Field, E: Into>( arena: &'ast Arena, ) -> Result<(ZirProgram<'ast, T>, Abi), CompileErrors> { let source = arena.alloc(source); - let compiled = compile_program::(source, location.clone(), resolver, &arena)?; + let compiled = compile_program::(source, location, resolver, &arena)?; // check semantics - let typed_ast = Checker::check(compiled).map_err(|errors| { - CompileErrors(errors.into_iter().map(|e| CompileError::from(e)).collect()) - })?; + let typed_ast = Checker::check(compiled) + .map_err(|errors| CompileErrors(errors.into_iter().map(CompileError::from).collect()))?; - let abi = typed_ast.abi(); + let main_module = typed_ast.main.clone(); // analyse (unroll and constant propagation) - let typed_ast = typed_ast.analyse(); - - Ok((typed_ast, abi)) + typed_ast + .analyse() + .map_err(|e| CompileErrors(vec![CompileErrorInner::from(e).in_file(&main_module)])) } pub fn compile_program<'ast, T: Field, E: Into>( @@ -244,7 +252,7 @@ pub fn compile_module<'ast, T: Field, E: Into>( let module_without_imports: Module = Module::from(ast); - Importer::new().apply_imports::( + Importer::apply_imports::( module_without_imports, location.clone(), resolver, @@ -380,17 +388,17 @@ struct Bar { field a } inputs: vec![AbiInput { name: "f".into(), public: true, - ty: Type::Struct(StructType::new( + ty: ConcreteType::Struct(ConcreteStructType::new( "foo".into(), "Foo".into(), - vec![StructMember { + vec![ConcreteStructMember { id: "b".into(), - ty: box Type::Struct(StructType::new( + ty: box ConcreteType::Struct(ConcreteStructType::new( "bar".into(), "Bar".into(), - vec![StructMember { + vec![ConcreteStructMember { id: "a".into(), - ty: box Type::FieldElement + ty: box ConcreteType::FieldElement }] )) }] diff --git a/zokrates_core/src/embed.rs b/zokrates_core/src/embed.rs index 76b1cd93..b5082579 100644 --- a/zokrates_core/src/embed.rs +++ b/zokrates_core/src/embed.rs @@ -3,7 +3,9 @@ use crate::flat_absy::{ FlatVariable, }; use crate::solvers::Solver; -use crate::typed_absy::types::{FunctionKey, Signature, Type}; +use crate::typed_absy::types::{ + ConcreteGenericsAssignment, Constant, DeclarationSignature, DeclarationType, +}; use std::collections::HashMap; use zokrates_field::{Bn128Field, Field}; @@ -18,11 +20,12 @@ cfg_if::cfg_if! { /// A low level function that contains non-deterministic introduction of variables. It is carried out as is until /// the flattening step when it can be inlined. -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] pub enum FlatEmbed { + U32ToField, #[cfg(feature = "bellman")] Sha256Round, - Unpack(usize), + Unpack, U8ToBits, U16ToBits, U32ToBits, @@ -32,48 +35,88 @@ pub enum FlatEmbed { } impl FlatEmbed { - pub fn signature(&self) -> Signature { + pub fn signature(&self) -> DeclarationSignature<'static> { match self { + FlatEmbed::U32ToField => DeclarationSignature::new() + .inputs(vec![DeclarationType::uint(32)]) + .outputs(vec![DeclarationType::FieldElement]), + FlatEmbed::Unpack => DeclarationSignature::new() + .generics(vec![Some(Constant::Generic("N"))]) + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + "N", + ))]), + FlatEmbed::U8ToBits => DeclarationSignature::new() + .inputs(vec![DeclarationType::uint(8)]) + .outputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 8usize, + ))]), + FlatEmbed::U16ToBits => DeclarationSignature::new() + .inputs(vec![DeclarationType::uint(16)]) + .outputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 16usize, + ))]), + FlatEmbed::U32ToBits => DeclarationSignature::new() + .inputs(vec![DeclarationType::uint(32)]) + .outputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 32usize, + ))]), + FlatEmbed::U8FromBits => DeclarationSignature::new() + .outputs(vec![DeclarationType::uint(8)]) + .inputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 8usize, + ))]), + FlatEmbed::U16FromBits => DeclarationSignature::new() + .outputs(vec![DeclarationType::uint(16)]) + .inputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 16usize, + ))]), + FlatEmbed::U32FromBits => DeclarationSignature::new() + .outputs(vec![DeclarationType::uint(32)]) + .inputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 32usize, + ))]), #[cfg(feature = "bellman")] - FlatEmbed::Sha256Round => Signature::new() + FlatEmbed::Sha256Round => DeclarationSignature::new() .inputs(vec![ - Type::array(Type::Boolean, 512), - Type::array(Type::Boolean, 256), + DeclarationType::array((DeclarationType::Boolean, 512usize)), + DeclarationType::array((DeclarationType::Boolean, 256usize)), ]) - .outputs(vec![Type::array(Type::Boolean, 256)]), - FlatEmbed::Unpack(bitwidth) => Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::array(Type::Boolean, *bitwidth)]), - FlatEmbed::U8ToBits => Signature::new() - .inputs(vec![Type::uint(8)]) - .outputs(vec![Type::array(Type::Boolean, 8)]), - FlatEmbed::U16ToBits => Signature::new() - .inputs(vec![Type::uint(16)]) - .outputs(vec![Type::array(Type::Boolean, 16)]), - FlatEmbed::U32ToBits => Signature::new() - .inputs(vec![Type::uint(32)]) - .outputs(vec![Type::array(Type::Boolean, 32)]), - FlatEmbed::U8FromBits => Signature::new() - .outputs(vec![Type::uint(8)]) - .inputs(vec![Type::array(Type::Boolean, 8)]), - FlatEmbed::U16FromBits => Signature::new() - .outputs(vec![Type::uint(16)]) - .inputs(vec![Type::array(Type::Boolean, 16)]), - FlatEmbed::U32FromBits => Signature::new() - .outputs(vec![Type::uint(32)]) - .inputs(vec![Type::array(Type::Boolean, 32)]), + .outputs(vec![DeclarationType::array(( + DeclarationType::Boolean, + 256usize, + ))]), } } - pub fn key(&self) -> FunctionKey<'static> { - FunctionKey::with_id(self.id()).signature(self.signature()) + pub fn generics<'ast>(&self, assignment: &ConcreteGenericsAssignment<'ast>) -> Vec { + let gen = self + .signature() + .generics + .into_iter() + .map(|c| match c.unwrap() { + Constant::Generic(g) => g, + _ => unreachable!(), + }); + + assert_eq!(gen.len(), assignment.0.len()); + gen.map(|g| *assignment.0.get(&g).clone().unwrap() as u32) + .collect() } pub fn id(&self) -> &'static str { match self { + FlatEmbed::U32ToField => "_U32_TO_FIELD", #[cfg(feature = "bellman")] FlatEmbed::Sha256Round => "_SHA256_ROUND", - FlatEmbed::Unpack(_) => "_UNPACK", + FlatEmbed::Unpack => "_UNPACK", FlatEmbed::U8ToBits => "_U8_TO_BITS", FlatEmbed::U16ToBits => "_U16_TO_BITS", FlatEmbed::U32ToBits => "_U32_TO_BITS", @@ -84,11 +127,11 @@ impl FlatEmbed { } /// Actually get the `FlatFunction` that this `FlatEmbed` represents - pub fn synthetize(&self) -> FlatFunction { + pub fn synthetize(&self, generics: &[u32]) -> FlatFunction { match self { #[cfg(feature = "bellman")] FlatEmbed::Sha256Round => sha256_round(), - FlatEmbed::Unpack(bitwidth) => unpack_to_bitwidth(*bitwidth), + FlatEmbed::Unpack => unpack_to_bitwidth(generics[0] as usize), _ => unreachable!(), } } @@ -101,7 +144,7 @@ fn flat_expression_from_vec(v: &[(usize, E::Fr)]) -> FlatEx match v.len() { 0 => FlatExpression::Number(T::zero()), 1 => { - let (key, val) = v[0].clone(); + let (key, val) = v[0]; let mut res: Vec = vec![]; val.into_repr().write_le(&mut res).unwrap(); FlatExpression::Mult( @@ -152,7 +195,7 @@ pub fn sha256_round() -> FlatFunction { let output_indices = output_indices.into_iter(); let variable_count = r1cs.aux_count + 1; // auxiliary and ONE // indices of the sha256round constraint system variables - let cs_indices = (0..variable_count).into_iter(); + let cs_indices = 0..variable_count; // indices of the arguments to the function // apply an offset of `variable_count` to get the indice of our dummy `input` argument let input_argument_indices = input_indices @@ -180,7 +223,7 @@ pub fn sha256_round() -> FlatFunction { ); let input_binding_statements = // bind input and current_hash to inputs - input_indices.clone().chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| { + input_indices.chain(current_hash_indices).zip(input_argument_indices.clone().chain(current_hash_argument_indices.clone())).map(|(cs_index, argument_index)| { FlatStatement::Condition( FlatVariable::new(cs_index).into(), FlatVariable::new(argument_index).into(), @@ -194,7 +237,7 @@ pub fn sha256_round() -> FlatFunction { .collect(); // insert a directive to set the witness based on the bellman gadget and inputs let directive_statement = FlatStatement::Directive(FlatDirective { - outputs: cs_indices.map(|i| FlatVariable::new(i)).collect(), + outputs: cs_indices.map(FlatVariable::new).collect(), inputs: input_argument_indices .chain(current_hash_argument_indices) .map(|i| FlatVariable::new(i).into()) @@ -224,7 +267,7 @@ fn use_variable( ) -> FlatVariable { let var = FlatVariable::new(*index); layout.insert(name, var); - *index = *index + 1; + *index += 1; var } @@ -255,7 +298,7 @@ pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { let directive_inputs = vec![FlatExpression::Identifier(use_variable( &mut layout, - format!("i0"), + "i0".into(), &mut counter, ))]; @@ -268,7 +311,7 @@ pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { let outputs = directive_outputs .iter() .enumerate() - .map(|(_, o)| FlatExpression::Identifier(o.clone())) + .map(|(_, o)| FlatExpression::Identifier(*o)) .collect::>(); // o253, o252, ... o{253 - (bit_width - 1)} are bits @@ -308,7 +351,7 @@ pub fn unpack_to_bitwidth(bit_width: usize) -> FlatFunction { FlatStatement::Directive(FlatDirective { inputs: directive_inputs, outputs: directive_outputs, - solver: solver, + solver, }), ); @@ -436,7 +479,7 @@ mod tests { private: vec![true; 768], }; - let input = (0..512) + let input: Vec<_> = (0..512) .map(|_| 0) .chain((0..256).map(|_| 1)) .map(|i| Bn128Field::from(i)) diff --git a/zokrates_core/src/flat_absy/flat_parameter.rs b/zokrates_core/src/flat_absy/flat_parameter.rs index 788a3ac4..0b9481cb 100644 --- a/zokrates_core/src/flat_absy/flat_parameter.rs +++ b/zokrates_core/src/flat_absy/flat_parameter.rs @@ -41,7 +41,7 @@ impl FlatParameter { substitution: &HashMap, ) -> FlatParameter { FlatParameter { - id: substitution.get(&self.id).unwrap().clone(), + id: *substitution.get(&self.id).unwrap(), private: self.private, } } diff --git a/zokrates_core/src/flat_absy/flat_variable.rs b/zokrates_core/src/flat_absy/flat_variable.rs index 5fc96c96..2679cd99 100644 --- a/zokrates_core/src/flat_absy/flat_variable.rs +++ b/zokrates_core/src/flat_absy/flat_variable.rs @@ -45,7 +45,7 @@ impl FlatVariable { Ok(FlatVariable::public(v)) } None => { - let mut private = s.split("_"); + let mut private = s.split('_'); match private.nth(1) { Some(v) => { let v = v.parse().map_err(|_| s)?; diff --git a/zokrates_core/src/flat_absy/mod.rs b/zokrates_core/src/flat_absy/mod.rs index b96c0c22..139032dd 100644 --- a/zokrates_core/src/flat_absy/mod.rs +++ b/zokrates_core/src/flat_absy/mod.rs @@ -155,23 +155,13 @@ impl FlatStatement { } } -#[derive(Clone, Hash, Debug)] +#[derive(Clone, Hash, Debug, PartialEq, Eq)] pub struct FlatDirective { pub inputs: Vec>, pub outputs: Vec, pub solver: Solver, } -impl PartialEq for FlatDirective { - fn eq(&self, other: &Self) -> bool { - self.inputs.eq(&other.inputs) - && self.outputs.eq(&other.outputs) - && self.solver.eq(&other.solver) - } -} - -impl Eq for FlatDirective {} - impl FlatDirective { pub fn new>>( outputs: Vec, @@ -249,12 +239,18 @@ impl FlatExpression { FlatExpression::Add(ref x, ref y) | FlatExpression::Sub(ref x, ref y) => { x.is_linear() && y.is_linear() } - FlatExpression::Mult(ref x, ref y) => match (x.clone(), y.clone()) { + FlatExpression::Mult(ref x, ref y) => matches!( + (x.clone(), y.clone()), (box FlatExpression::Number(_), box FlatExpression::Number(_)) - | (box FlatExpression::Number(_), box FlatExpression::Identifier(_)) - | (box FlatExpression::Identifier(_), box FlatExpression::Number(_)) => true, - _ => false, - }, + | ( + box FlatExpression::Number(_), + box FlatExpression::Identifier(_) + ) + | ( + box FlatExpression::Identifier(_), + box FlatExpression::Number(_) + ) + ), } } } diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 1136544b..fbec3419 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::needless_collect)] + //! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs @@ -8,11 +10,13 @@ mod utils; use self::utils::flat_expression_from_bits; +use crate::ir::Interpreter; use crate::compile::CompileConfig; +use crate::embed::FlatEmbed; use crate::flat_absy::*; -use crate::solvers::{Executable, Solver}; -use crate::zir::types::{FunctionIdentifier, FunctionKey, Signature, Type, UBitwidth}; +use crate::solvers::Solver; +use crate::zir::types::{Type, UBitwidth}; use crate::zir::*; use std::collections::hash_map::Entry; use std::collections::HashMap; @@ -29,8 +33,6 @@ pub struct Flattener<'ast, T: Field> { next_var_idx: usize, /// `FlatVariable`s corresponding to each `Identifier` layout: HashMap, FlatVariable>, - /// Cached `FlatFunction`s to avoid re-flattening them - flat_cache: HashMap, FlatFunction>, /// Cached bit decompositions to avoid re-generating them bits_cache: HashMap, Vec>>, } @@ -59,7 +61,6 @@ trait Flatten<'ast, T: Field>: TryFrom, Error = ()> + IfE fn flatten( self, flattener: &mut Flattener<'ast, T>, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, ) -> Self::Output; } @@ -70,10 +71,9 @@ impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, ) -> Self::Output { - flattener.flatten_field_expression(symbols, statements_flattened, self) + flattener.flatten_field_expression(statements_flattened, self) } } @@ -83,10 +83,9 @@ impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, ) -> Self::Output { - flattener.flatten_uint_expression(symbols, statements_flattened, self) + flattener.flatten_uint_expression(statements_flattened, self) } } @@ -96,10 +95,9 @@ impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { fn flatten( self, flattener: &mut Flattener<'ast, T>, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, ) -> Self::Output { - flattener.flatten_boolean_expression(symbols, statements_flattened, self) + flattener.flatten_boolean_expression(statements_flattened, self) } } @@ -160,7 +158,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { config, next_var_idx: 0, layout: HashMap::new(), - flat_cache: HashMap::new(), bits_cache: HashMap::new(), } } @@ -172,7 +169,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened: &mut FlatStatements, ) -> FlatVariable { match e { - FlatExpression::Identifier(id) => id.into(), + FlatExpression::Identifier(id) => id, e => { let res = self.use_sym(); statements_flattened.push(FlatStatement::Definition(res, e)); @@ -251,10 +248,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // sizeUnknown is not changing in this case // We sill have to assign the old value to the variable of the current run // This trivial definition will later be removed by the optimiser - FlatStatement::Definition( - size_unknown[i + 1].into(), - size_unknown[i].into(), - ), + FlatStatement::Definition(size_unknown[i + 1], size_unknown[i].into()), ); } @@ -287,7 +281,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `condition` - the condition as a `BooleanExpression`. /// * `consequence` - the consequence of type U. @@ -296,17 +289,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * U is the type of the expression fn flatten_if_else_expression>( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, condition: BooleanExpression<'ast, T>, consequence: U, alternative: U, ) -> FlatUExpression { - let condition = self.flatten_boolean_expression(symbols, statements_flattened, condition); + let condition = self.flatten_boolean_expression(statements_flattened, condition); - let consequence = consequence.flatten(self, symbols, statements_flattened); + let consequence = consequence.flatten(self, statements_flattened); - let alternative = alternative.flatten(self, symbols, statements_flattened); + let alternative = alternative.flatten(self, statements_flattened); let condition_id = self.use_sym(); statements_flattened.push(FlatStatement::Definition(condition_id, condition)); @@ -360,7 +352,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expression` - `BooleanExpression` that will be flattened. /// @@ -370,16 +361,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// * in order to preserve composability. fn flatten_boolean_expression( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { // those will be booleans in the future match expression { BooleanExpression::Identifier(x) => { - FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()) + FlatExpression::Identifier(*self.layout.get(&x).unwrap()) } - BooleanExpression::Lt(box lhs, box rhs) => { + BooleanExpression::FieldLt(box lhs, box rhs) => { // Get the bit width to know the size of the binary decompositions for this Field let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // making sure we don't overflow, assert here? @@ -387,10 +377,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { // We know from semantic checking that lhs and rhs have the same type // What the expression will flatten to depends on that type - let lhs_flattened = - self.flatten_field_expression(symbols, statements_flattened, lhs); - let rhs_flattened = - self.flatten_field_expression(symbols, statements_flattened, rhs); + let lhs_flattened = self.flatten_field_expression(statements_flattened, lhs); + let rhs_flattened = self.flatten_field_expression(statements_flattened, rhs); // lhs let lhs_id = self.define(lhs_flattened, statements_flattened); @@ -411,12 +399,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { ))); // bitness checks - for i in 0..safe_width { + for bit in lhs_bits_be.iter().take(safe_width) { statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(lhs_bits_be[i]), + FlatExpression::Identifier(*bit), FlatExpression::Mult( - box FlatExpression::Identifier(lhs_bits_be[i]), - box FlatExpression::Identifier(lhs_bits_be[i]), + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), ), )); } @@ -424,11 +412,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bit decomposition check let mut lhs_sum = FlatExpression::Number(T::from(0)); - for i in 0..safe_width { + for (i, bit) in lhs_bits_be.iter().enumerate().take(safe_width) { lhs_sum = FlatExpression::Add( box lhs_sum, box FlatExpression::Mult( - box FlatExpression::Identifier(lhs_bits_be[i]), + box FlatExpression::Identifier(*bit), box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)), ), ); @@ -457,12 +445,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { ))); // bitness checks - for i in 0..safe_width { + for bit in rhs_bits_be.iter().take(safe_width) { statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(rhs_bits_be[i]), + FlatExpression::Identifier(*bit), FlatExpression::Mult( - box FlatExpression::Identifier(rhs_bits_be[i]), - box FlatExpression::Identifier(rhs_bits_be[i]), + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), ), )); } @@ -470,11 +458,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { // bit decomposition check let mut rhs_sum = FlatExpression::Number(T::from(0)); - for i in 0..safe_width { + for (i, bit) in rhs_bits_be.iter().enumerate().take(safe_width) { rhs_sum = FlatExpression::Add( box rhs_sum, box FlatExpression::Mult( - box FlatExpression::Identifier(rhs_bits_be[i]), + box FlatExpression::Identifier(*bit), box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)), ), ); @@ -510,12 +498,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { ))); // bitness checks - for i in 0..bit_width { + for bit in sub_bits_be.iter().take(bit_width) { statements_flattened.push(FlatStatement::Condition( - FlatExpression::Identifier(sub_bits_be[i]), + FlatExpression::Identifier(*bit), FlatExpression::Mult( - box FlatExpression::Identifier(sub_bits_be[i]), - box FlatExpression::Identifier(sub_bits_be[i]), + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), ), )); } @@ -530,11 +518,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { // sum(sym_b{i} * 2**i) let mut expr = FlatExpression::Number(T::from(0)); - for i in 0..bit_width { + for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) { expr = FlatExpression::Add( box expr, box FlatExpression::Mult( - box FlatExpression::Identifier(sub_bits_be[i]), + box FlatExpression::Identifier(*bit), box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)), ), ); @@ -546,8 +534,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { } BooleanExpression::BoolEq(box lhs, box rhs) => { // lhs and rhs are booleans, they flatten to 0 or 1 - let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs); - let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs); + let x = self.flatten_boolean_expression(statements_flattened, lhs); + let y = self.flatten_boolean_expression(statements_flattened, rhs); // Wanted: Not(X - Y)**2 which is an XNOR // We know that X and Y are [0, 1] // (X - Y) can become a negative values, which is why squaring the result is needed @@ -588,7 +576,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { let name_m = self.use_sym(); let x = self.flatten_field_expression( - symbols, statements_flattened, FieldElementExpression::Sub(box lhs, box rhs), ); @@ -632,10 +619,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(rhs.metadata.clone().unwrap().should_reduce.to_bool()); let lhs = self - .flatten_uint_expression(symbols, statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, lhs) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(symbols, statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, rhs) .get_field_unchecked(); let x = FlatExpression::Sub(box lhs, box rhs); @@ -662,32 +649,117 @@ impl<'ast, T: Field> Flattener<'ast, T> { res } - BooleanExpression::Le(box lhs, box rhs) => { + BooleanExpression::FieldLe(box lhs, box rhs) => { let lt = self.flatten_boolean_expression( - symbols, statements_flattened, - BooleanExpression::Lt(box lhs.clone(), box rhs.clone()), + BooleanExpression::FieldLt(box lhs.clone(), box rhs.clone()), ); let eq = self.flatten_boolean_expression( - symbols, statements_flattened, BooleanExpression::FieldEq(box lhs.clone(), box rhs.clone()), ); FlatExpression::Add(box eq, box lt) } - BooleanExpression::Gt(lhs, rhs) => self.flatten_boolean_expression( - symbols, + BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression( statements_flattened, - BooleanExpression::Lt(rhs, lhs), + BooleanExpression::FieldLt(rhs, lhs), ), - BooleanExpression::Ge(lhs, rhs) => self.flatten_boolean_expression( - symbols, + BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression( statements_flattened, - BooleanExpression::Le(rhs, lhs), + BooleanExpression::FieldLe(rhs, lhs), + ), + BooleanExpression::UintLt(box lhs, box rhs) => { + let lhs_flattened = self.flatten_uint_expression(statements_flattened, lhs); + let rhs_flattened = self.flatten_uint_expression(statements_flattened, rhs); + + // Get the bit width to know the size of the binary decompositions for this Field + // This is not this uint bitwidth + let bit_width = T::get_required_bits(); + + // lhs + let lhs_id = self.define(lhs_flattened.get_field_unchecked(), statements_flattened); + let rhs_id = self.define(rhs_flattened.get_field_unchecked(), statements_flattened); + + // sym := (lhs * 2) - (rhs * 2) + let subtraction_result = FlatExpression::Sub( + box FlatExpression::Mult( + box FlatExpression::Number(T::from(2)), + box FlatExpression::Identifier(lhs_id), + ), + box FlatExpression::Mult( + box FlatExpression::Number(T::from(2)), + box FlatExpression::Identifier(rhs_id), + ), + ); + + // define variables for the bits + let sub_bits_be: Vec = + (0..bit_width).map(|_| self.use_sym()).collect(); + + // add a directive to get the bits + statements_flattened.push(FlatStatement::Directive(FlatDirective::new( + sub_bits_be.clone(), + Solver::bits(bit_width), + vec![subtraction_result.clone()], + ))); + + // bitness checks + for bit in sub_bits_be.iter().take(bit_width) { + statements_flattened.push(FlatStatement::Condition( + FlatExpression::Identifier(*bit), + FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Identifier(*bit), + ), + )); + } + + // check that the decomposition is in the field with a strict `< p` checks + self.strict_le_check( + statements_flattened, + &T::max_value_bit_vector_be(), + sub_bits_be.clone(), + ); + + // sum(sym_b{i} * 2**i) + let mut expr = FlatExpression::Number(T::from(0)); + + for (i, bit) in sub_bits_be.iter().enumerate().take(bit_width) { + expr = FlatExpression::Add( + box expr, + box FlatExpression::Mult( + box FlatExpression::Identifier(*bit), + box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)), + ), + ); + } + + statements_flattened.push(FlatStatement::Condition(subtraction_result, expr)); + + FlatExpression::Identifier(sub_bits_be[bit_width - 1]) + } + BooleanExpression::UintLe(box lhs, box rhs) => { + let lt = self.flatten_boolean_expression( + statements_flattened, + BooleanExpression::UintLt(box lhs.clone(), box rhs.clone()), + ); + let eq = self.flatten_boolean_expression( + statements_flattened, + BooleanExpression::UintEq(box lhs.clone(), box rhs.clone()), + ); + FlatExpression::Add(box eq, box lt) + } + BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression( + statements_flattened, + BooleanExpression::UintLt(rhs, lhs), + ), + BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression( + statements_flattened, + BooleanExpression::UintLe(rhs, lhs), ), BooleanExpression::Or(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs); - let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs); + let x = self.flatten_boolean_expression(statements_flattened, lhs); + let y = self.flatten_boolean_expression(statements_flattened, rhs); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); statements_flattened.push(FlatStatement::Directive(FlatDirective { @@ -705,8 +777,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { name_x_or_y.into() } BooleanExpression::And(box lhs, box rhs) => { - let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs); - let y = self.flatten_boolean_expression(symbols, statements_flattened, rhs); + let x = self.flatten_boolean_expression(statements_flattened, lhs); + let y = self.flatten_boolean_expression(statements_flattened, rhs); let name_x_and_y = self.use_sym(); assert!(x.is_linear() && y.is_linear()); @@ -718,7 +790,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Identifier(name_x_and_y) } BooleanExpression::Not(box exp) => { - let x = self.flatten_boolean_expression(symbols, statements_flattened, exp); + let x = self.flatten_boolean_expression(statements_flattened, exp); FlatExpression::Sub(box FlatExpression::Number(T::one()), box x) } BooleanExpression::Value(b) => FlatExpression::Number(match b { @@ -727,38 +799,32 @@ impl<'ast, T: Field> Flattener<'ast, T> { }), BooleanExpression::IfElse(box condition, box consequence, box alternative) => self .flatten_if_else_expression( - symbols, statements_flattened, condition, consequence, alternative, ) - .get_field_unchecked() - .clone(), + .get_field_unchecked(), } } fn flatten_u_to_bits( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, expression: ZirExpression<'ast, T>, bitwidth: UBitwidth, ) -> Vec> { let expression = UExpression::try_from(expression).unwrap(); let from = expression.metadata.clone().unwrap().bitwidth(); - let p = self.flatten_uint_expression(symbols, statements_flattened, expression); - let bits = self - .get_bits(p, from as usize, bitwidth, statements_flattened) + let p = self.flatten_uint_expression(statements_flattened, expression); + self.get_bits(p, from as usize, bitwidth, statements_flattened) .into_iter() - .map(|b| FlatUExpression::with_field(b)) - .collect(); - bits + .map(FlatUExpression::with_field) + .collect() } fn flatten_bits_to_u( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, bits: Vec>, bitwidth: UBitwidth, @@ -767,7 +833,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let bits: Vec<_> = bits .into_iter() .map(|p| { - self.flatten_expression(symbols, statements_flattened, p) + self.flatten_expression(statements_flattened, p) .get_field_unchecked() }) .collect(); @@ -779,76 +845,53 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `id` - `Identifier of the function. /// * `return_types` - Types of the return values of the function /// * `param_expressions` - Arguments of this call - fn flatten_function_call( + fn flatten_embed_call( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, - id: FunctionIdentifier<'ast>, - return_types: Vec, + embed: FlatEmbed, + generics: Vec, param_expressions: Vec>, ) -> Vec> { - let passed_signature = Signature::new() - .inputs(param_expressions.iter().map(|e| e.get_type()).collect()) - .outputs(return_types); - - let key = FunctionKey::with_id(id).signature(passed_signature); - - let funct = self.get_embed(&key, &symbols); - - match funct { + match embed { crate::embed::FlatEmbed::U32ToBits => self.flatten_u_to_bits( - symbols, statements_flattened, param_expressions[0].clone(), 32.into(), ), crate::embed::FlatEmbed::U16ToBits => self.flatten_u_to_bits( - symbols, statements_flattened, param_expressions[0].clone(), 16.into(), ), - crate::embed::FlatEmbed::U8ToBits => self.flatten_u_to_bits( - symbols, - statements_flattened, - param_expressions[0].clone(), - 8.into(), - ), - crate::embed::FlatEmbed::U32FromBits => vec![self.flatten_bits_to_u( - symbols, - statements_flattened, - param_expressions, - 32.into(), - )], - crate::embed::FlatEmbed::U16FromBits => vec![self.flatten_bits_to_u( - symbols, - statements_flattened, - param_expressions, - 16.into(), - )], - crate::embed::FlatEmbed::U8FromBits => vec![self.flatten_bits_to_u( - symbols, - statements_flattened, - param_expressions, - 8.into(), - )], + crate::embed::FlatEmbed::U8ToBits => { + self.flatten_u_to_bits(statements_flattened, param_expressions[0].clone(), 8.into()) + } + crate::embed::FlatEmbed::U32FromBits => { + vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 32.into())] + } + crate::embed::FlatEmbed::U16FromBits => { + vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 16.into())] + } + crate::embed::FlatEmbed::U8FromBits => { + vec![self.flatten_bits_to_u(statements_flattened, param_expressions, 8.into())] + } funct => { - let funct = funct.synthetize(); + let funct = funct.synthetize(&generics); let mut replacement_map = HashMap::new(); // Handle complex parameters and assign values: // Rename Parameters, assign them to values in call. Resolve complex expressions with definitions + // Clippy doesn't like the fact that we're collecting here, however not doing so leads to a borrow issue + // of `self` in the for-loop just after. This is why the `needless_collect` lint is disabled for this file + // (it does not work for this single line) let params_flattened = param_expressions .into_iter() - .map(|param_expr| { - self.flatten_expression(symbols, statements_flattened, param_expr) - }) + .map(|param_expr| self.flatten_expression(statements_flattened, param_expr)) .into_iter() .map(|x| x.get_field_unchecked()) .collect::>(); @@ -863,11 +906,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { // Ensure renaming and correct returns: // add all flattened statements, adapt return statements - let (mut return_statements, statements): (Vec<_>, Vec<_>) = - funct.statements.into_iter().partition(|s| match s { - FlatStatement::Return(..) => true, - _ => false, - }); + let (mut return_statements, statements): (Vec<_>, Vec<_>) = funct + .statements + .into_iter() + .partition(|s| matches!(s, FlatStatement::Return(..))); let statements: Vec<_> = statements .into_iter() @@ -916,7 +958,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .expressions .into_iter() .map(|x| x.apply_substitution(&replacement_map)) - .map(|x| FlatUExpression::with_field(x)) + .map(FlatUExpression::with_field) .collect(), _ => unreachable!(), } @@ -928,37 +970,32 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `ZirExpression` that will be flattened. fn flatten_expression( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, expr: ZirExpression<'ast, T>, ) -> FlatUExpression { match expr { - ZirExpression::FieldElement(e) => FlatUExpression::with_field( - self.flatten_field_expression(symbols, statements_flattened, e), - ), - ZirExpression::Boolean(e) => FlatUExpression::with_field( - self.flatten_boolean_expression(symbols, statements_flattened, e), - ), - ZirExpression::Uint(e) => { - self.flatten_uint_expression(symbols, statements_flattened, e) + ZirExpression::FieldElement(e) => { + FlatUExpression::with_field(self.flatten_field_expression(statements_flattened, e)) } + ZirExpression::Boolean(e) => FlatUExpression::with_field( + self.flatten_boolean_expression(statements_flattened, e), + ), + ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e), } } fn default_xor( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> FlatUExpression { - let left_flattened = self.flatten_uint_expression(symbols, statements_flattened, left); - let right_flattened = self.flatten_uint_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_uint_expression(statements_flattened, left); + let right_flattened = self.flatten_uint_expression(statements_flattened, right); // `left` and `right` were reduced to the target bitwidth, hence their bits are available @@ -990,7 +1027,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.extend(vec![ FlatStatement::Directive(FlatDirective::new( - vec![name.clone()], + vec![name], Solver::Xor, vec![x.clone(), y.clone()], )), @@ -1016,17 +1053,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn euclidean_division( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, target_bitwidth: UBitwidth, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> (FlatExpression, FlatExpression) { let left_flattened = self - .flatten_uint_expression(symbols, statements_flattened, left) + .flatten_uint_expression(statements_flattened, left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(symbols, statements_flattened, right) + .flatten_uint_expression(statements_flattened, right) .get_field_unchecked(); let n = if left_flattened.is_linear() { left_flattened @@ -1056,7 +1092,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // assert(invd * d == 1) statements_flattened.push(FlatStatement::Condition( FlatExpression::Number(T::one()), - FlatExpression::Mult(box invd.into(), box d.clone().into()), + FlatExpression::Mult(box invd.into(), box d.clone()), )); // now introduce the quotient and remainder @@ -1065,7 +1101,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.push(FlatStatement::Directive(FlatDirective { inputs: vec![n.clone(), d.clone()], - outputs: vec![q.clone(), r.clone()], + outputs: vec![q, r], solver: Solver::EuclideanDiv, })); @@ -1088,7 +1124,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // r < d <=> r - d + 2**w < 2**w let _ = self.get_bits( FlatUExpression::with_field(FlatExpression::Add( - box FlatExpression::Sub(box r.into(), box d.clone().into()), + box FlatExpression::Sub(box r.into(), box d.clone()), box FlatExpression::Number(T::from(2usize.pow(target_bitwidth.to_usize() as u32))), )), target_bitwidth.to_usize(), @@ -1109,19 +1145,17 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `UExpression` that will be flattened. fn flatten_uint_expression( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, expr: UExpression<'ast, T>, ) -> FlatUExpression { // the bitwidth for this type of uint (8, 16 or 32) let target_bitwidth = expr.bitwidth; - let metadata = expr.metadata.clone().unwrap().clone(); + let metadata = expr.metadata.clone().unwrap(); // the bitwidth on which this value is currently represented let actual_bitwidth = metadata.bitwidth() as usize; @@ -1136,7 +1170,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_field(FlatExpression::Number(T::from(x as usize))) } // force to be a field element UExpressionInner::Identifier(x) => { - let field = FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()); + let field = FlatExpression::Identifier(*self.layout.get(&x).unwrap()); let bits = self.bits_cache.get(&field).map(|bits| { assert_eq!(bits.len(), target_bitwidth.to_usize()); bits.clone() @@ -1144,7 +1178,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_field(field).bits(bits) } UExpressionInner::Not(box e) => { - let e = self.flatten_uint_expression(symbols, statements_flattened, e); + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1165,11 +1199,11 @@ impl<'ast, T: Field> Flattener<'ast, T> { } UExpressionInner::Add(box left, box right) => { let left_flattened = self - .flatten_uint_expression(symbols, statements_flattened, left) + .flatten_uint_expression(statements_flattened, left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(symbols, statements_flattened, right) + .flatten_uint_expression(statements_flattened, right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { @@ -1196,10 +1230,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { ); let left_flattened = self - .flatten_uint_expression(symbols, statements_flattened, left) + .flatten_uint_expression(statements_flattened, left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(symbols, statements_flattened, right) + .flatten_uint_expression(statements_flattened, right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened @@ -1233,7 +1267,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), }; - let e = self.flatten_uint_expression(symbols, statements_flattened, e); + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1262,7 +1296,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), }; - let e = self.flatten_uint_expression(symbols, statements_flattened, e); + let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1280,10 +1314,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { } UExpressionInner::Mult(box left, box right) => { let left_flattened = self - .flatten_uint_expression(symbols, statements_flattened, left) + .flatten_uint_expression(statements_flattened, left) .get_field_unchecked(); let right_flattened = self - .flatten_uint_expression(symbols, statements_flattened, right) + .flatten_uint_expression(statements_flattened, right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened @@ -1310,36 +1344,24 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_field(FlatExpression::Identifier(res)) } UExpressionInner::Div(box left, box right) => { - let (q, _) = self.euclidean_division( - symbols, - statements_flattened, - target_bitwidth, - left, - right, - ); + let (q, _) = + self.euclidean_division(statements_flattened, target_bitwidth, left, right); FlatUExpression::with_field(q) } UExpressionInner::Rem(box left, box right) => { - let (_, r) = self.euclidean_division( - symbols, - statements_flattened, - target_bitwidth, - left, - right, - ); + let (_, r) = + self.euclidean_division(statements_flattened, target_bitwidth, left, right); FlatUExpression::with_field(r) } UExpressionInner::IfElse(box condition, box consequence, box alternative) => self .flatten_if_else_expression( - symbols, statements_flattened, condition, consequence, alternative, - ) - .clone(), + ), UExpressionInner::Xor(box left, box right) => { let left_metadata = left.metadata.clone().unwrap(); let right_metadata = right.metadata.clone().unwrap(); @@ -1347,12 +1369,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { match (left.into_inner(), right.into_inner()) { (UExpressionInner::And(box a, box b), UExpressionInner::And(box aa, box c)) => { if aa.clone().into_inner() == UExpressionInner::Not(box a.clone()) { - let a_flattened = - self.flatten_uint_expression(symbols, statements_flattened, a); - let b_flattened = - self.flatten_uint_expression(symbols, statements_flattened, b); - let c_flattened = - self.flatten_uint_expression(symbols, statements_flattened, c); + let a_flattened = self.flatten_uint_expression(statements_flattened, a); + let b_flattened = self.flatten_uint_expression(statements_flattened, b); + let c_flattened = self.flatten_uint_expression(statements_flattened, c); let a_bits = a_flattened.bits.unwrap(); let b_bits = b_flattened.bits.unwrap(); @@ -1369,7 +1388,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.extend(vec![ FlatStatement::Directive(FlatDirective::new( - vec![ch.clone()], + vec![ch], Solver::ShaCh, vec![a.clone(), b.clone(), c.clone()], )), @@ -1388,7 +1407,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(res) } else { self.default_xor( - symbols, statements_flattened, UExpressionInner::And(box a, box b) .annotate(target_bitwidth) @@ -1410,21 +1428,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { UExpressionInner::And(box bb, box cc), ) => { if (aa == a) && (bb == b) && (cc == c) { - let a_flattened = self.flatten_uint_expression( - symbols, - statements_flattened, - a, - ); - let b_flattened = self.flatten_uint_expression( - symbols, - statements_flattened, - b, - ); - let c_flattened = self.flatten_uint_expression( - symbols, - statements_flattened, - c, - ); + let a_flattened = + self.flatten_uint_expression(statements_flattened, a); + let b_flattened = + self.flatten_uint_expression(statements_flattened, b); + let c_flattened = + self.flatten_uint_expression(statements_flattened, c); let a_bits = a_flattened.bits.unwrap(); let b_bits = b_flattened.bits.unwrap(); @@ -1443,7 +1452,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.extend(vec![ FlatStatement::Directive(FlatDirective::new( - vec![maj.clone()], + vec![maj], Solver::ShaAndXorAndXorAnd, vec![a.clone(), b.clone(), c.clone()], )), @@ -1478,7 +1487,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(res) } else { self.default_xor( - symbols, statements_flattened, UExpressionInner::Xor( box UExpressionInner::And(box a, box b) @@ -1497,7 +1505,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } (a, b, c) => self.default_xor( - symbols, statements_flattened, UExpressionInner::Xor( box a.annotate(target_bitwidth).metadata(a_metadata), @@ -1510,7 +1517,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } (left_i, right_i) => self.default_xor( - symbols, statements_flattened, left_i.annotate(target_bitwidth).metadata(left_metadata), right_i.annotate(target_bitwidth).metadata(right_metadata), @@ -1518,11 +1524,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { } } UExpressionInner::And(box left, box right) => { - let left_flattened = - self.flatten_uint_expression(symbols, statements_flattened, left); + let left_flattened = self.flatten_uint_expression(statements_flattened, left); - let right_flattened = - self.flatten_uint_expression(symbols, statements_flattened, right); + let right_flattened = self.flatten_uint_expression(statements_flattened, right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1555,10 +1559,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatUExpression::with_bits(and) } UExpressionInner::Or(box left, box right) => { - let left_flattened = - self.flatten_uint_expression(symbols, statements_flattened, left); - let right_flattened = - self.flatten_uint_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_uint_expression(statements_flattened, left); + let right_flattened = self.flatten_uint_expression(statements_flattened, right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); @@ -1587,7 +1589,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.extend(vec![ FlatStatement::Directive(FlatDirective::new( - vec![name.clone()], + vec![name], Solver::Or, vec![x.clone(), y.clone()], )), @@ -1630,7 +1632,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { box FlatExpression::Number( T::from(2).pow(target_bitwidth.to_usize() - index - 1), ), - box bit.clone().into(), + box bit.clone(), ), ) }, @@ -1660,23 +1662,19 @@ impl<'ast, T: Field> Flattener<'ast, T> { assert!(to < T::get_required_bits()); // constants do not require directives - match e.field { - Some(FlatExpression::Number(ref x)) => { - let solver = Solver::bits(to); - let bits: Vec<_> = solver - .execute(&vec![x.clone()]) - .unwrap() - .into_iter() - .map(|x| FlatExpression::Number(x)) - .collect(); + if let Some(FlatExpression::Number(ref x)) = e.field { + let bits: Vec<_> = Interpreter::default() + .execute_solver(&Solver::bits(to), &[x.clone()]) + .unwrap() + .into_iter() + .map(FlatExpression::Number) + .collect(); - assert_eq!(bits.len(), to); + assert_eq!(bits.len(), to); - self.bits_cache - .insert(e.field.clone().unwrap(), bits.clone()); - return bits; - } - _ => {} + self.bits_cache + .insert(e.field.clone().unwrap(), bits.clone()); + return bits; }; e.bits.clone().unwrap_or_else(|| { @@ -1687,7 +1685,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { let from = std::cmp::max(from, to); match self.bits_cache.entry(e.field.clone().unwrap()) { Entry::Occupied(entry) => { - let res: Vec<_> = entry.get().clone().into_iter().map(|e| e.into()).collect(); + let res: Vec<_> = entry.get().clone(); // if we already know a decomposition, it has to be of the size of the target bitwidth assert_eq!(res.len(), to); res @@ -1700,21 +1698,15 @@ impl<'ast, T: Field> Flattener<'ast, T> { vec![e.field.clone().unwrap()], ))); - let bits: Vec<_> = bits - .into_iter() - .map(|b| FlatExpression::Identifier(b)) - .collect(); + let bits: Vec<_> = bits.into_iter().map(FlatExpression::Identifier).collect(); // decompose to the actual bitwidth // bit checks - statements_flattened.extend((0..from).map(|i| { + statements_flattened.extend(bits.iter().take(from).map(|bit| { FlatStatement::Condition( - bits[i].clone(), - FlatExpression::Mult( - box bits[i].clone().into(), - box bits[i].clone().into(), - ), + bit.clone(), + FlatExpression::Mult(box bit.clone(), box bit.clone()), ) })); @@ -1734,7 +1726,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { self.bits_cache.insert(e.field.unwrap(), bits.clone()); self.bits_cache.insert(sum, bits.clone()); - bits.into_iter().map(|v| v.into()).collect() + bits } } }) @@ -1744,25 +1736,21 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `FieldElementExpression` that will be flattened. fn flatten_field_expression( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { match expr { FieldElementExpression::Number(x) => FlatExpression::Number(x), // force to be a field element FieldElementExpression::Identifier(x) => { - FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()) + FlatExpression::Identifier(*self.layout.get(&x).unwrap_or_else(|| panic!("{}", x))) } FieldElementExpression::Add(box left, box right) => { - let left_flattened = - self.flatten_field_expression(symbols, statements_flattened, left); - let right_flattened = - self.flatten_field_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_field_expression(statements_flattened, left); + let right_flattened = self.flatten_field_expression(statements_flattened, right); let new_left = if left_flattened.is_linear() { left_flattened } else { @@ -1780,10 +1768,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Add(box new_left, box new_right) } FieldElementExpression::Sub(box left, box right) => { - let left_flattened = - self.flatten_field_expression(symbols, statements_flattened, left); - let right_flattened = - self.flatten_field_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_field_expression(statements_flattened, left); + let right_flattened = self.flatten_field_expression(statements_flattened, right); let new_left = if left_flattened.is_linear() { left_flattened @@ -1803,10 +1789,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Sub(box new_left, box new_right) } FieldElementExpression::Mult(box left, box right) => { - let left_flattened = - self.flatten_field_expression(symbols, statements_flattened, left); - let right_flattened = - self.flatten_field_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_field_expression(statements_flattened, left); + let right_flattened = self.flatten_field_expression(statements_flattened, right); let new_left = if left_flattened.is_linear() { left_flattened } else { @@ -1824,10 +1808,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { FlatExpression::Mult(box new_left, box new_right) } FieldElementExpression::Div(box left, box right) => { - let left_flattened = - self.flatten_field_expression(symbols, statements_flattened, left); - let right_flattened = - self.flatten_field_expression(symbols, statements_flattened, right); + let left_flattened = self.flatten_field_expression(statements_flattened, left); + let right_flattened = self.flatten_field_expression(statements_flattened, right); let new_left: FlatExpression = { let id = self.use_sym(); statements_flattened.push(FlatStatement::Definition(id, left_flattened)); @@ -1852,7 +1834,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // assert(invb * b == 1) statements_flattened.push(FlatStatement::Condition( FlatExpression::Number(T::one()), - FlatExpression::Mult(box invb.into(), box new_right.clone().into()), + FlatExpression::Mult(box invb.into(), box new_right.clone()), )); // # c = a/b @@ -1864,28 +1846,23 @@ impl<'ast, T: Field> Flattener<'ast, T> { // assert(c * b == a) statements_flattened.push(FlatStatement::Condition( - new_left.into(), + new_left, FlatExpression::Mult(box new_right, box inverse.into()), )); inverse.into() } FieldElementExpression::Pow(box base, box exponent) => { - match exponent { - FieldElementExpression::Number(ref e) => { + match exponent.into_inner() { + UExpressionInner::Value(ref e) => { // flatten the base expression - let base_flattened = self.flatten_field_expression( - symbols, - statements_flattened, - base.clone(), - ); + let base_flattened = + self.flatten_field_expression(statements_flattened, base.clone()); // we require from the base to be linear // TODO change that assert!(base_flattened.is_linear()); - let e = e.to_dec_string().parse::().unwrap(); - // convert the exponent to bytes, big endian let ebytes_be = e.to_be_bytes(); // convert the bytes to bits, remove leading zeroes (we only need powers up to the highest non-zero bit) @@ -1914,16 +1891,16 @@ impl<'ast, T: Field> Flattener<'ast, T> { let id = self.use_sym(); // set it to the square of the previous one, stored in state statements_flattened.push(FlatStatement::Definition( - id.clone(), + id, FlatExpression::Mult( box previous.clone(), box previous.clone(), ), )); // store it in the state for later squaring - *state = Some(FlatExpression::Identifier(id.clone())); + *state = Some(FlatExpression::Identifier(id)); // return it for later use constructing the result - Some(FlatExpression::Identifier(id.clone())) + Some(FlatExpression::Identifier(id)) } } }) @@ -1951,14 +1928,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { } FieldElementExpression::IfElse(box condition, box consequence, box alternative) => self .flatten_if_else_expression( - symbols, statements_flattened, condition, consequence, alternative, ) - .get_field_unchecked() - .clone(), + .get_field_unchecked(), } } @@ -1966,12 +1941,10 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// # Arguments /// - /// * `symbols` - Available functions in this context /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `stat` - `ZirStatement` that will be flattened. fn flatten_statement( &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, statements_flattened: &mut FlatStatements, stat: ZirStatement<'ast, T>, ) { @@ -1979,7 +1952,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { ZirStatement::Return(exprs) => { let flat_expressions = exprs .into_iter() - .map(|expr| self.flatten_expression(symbols, statements_flattened, expr)) + .map(|expr| self.flatten_expression(statements_flattened, expr)) .map(|x| x.get_field_unchecked()) .collect::>(); @@ -1989,13 +1962,12 @@ impl<'ast, T: Field> Flattener<'ast, T> { } ZirStatement::Declaration(_) => { // declarations have already been checked - () } ZirStatement::Definition(assignee, expr) => { // define n variables with n the number of primitive types for v_type // assign them to the n primitive types for expr - let rhs = self.flatten_expression(symbols, statements_flattened, expr); + let rhs = self.flatten_expression(statements_flattened, expr); let bits = rhs.bits.clone(); @@ -2015,12 +1987,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { }; // register bits - match bits { - Some(bits) => { - self.bits_cache - .insert(FlatExpression::Identifier(var), bits); - } - None => {} + if let Some(bits) = bits { + self.bits_cache + .insert(FlatExpression::Identifier(var), bits); } } ZirStatement::Assertion(e) => { @@ -2028,39 +1997,36 @@ impl<'ast, T: Field> Flattener<'ast, T> { BooleanExpression::And(..) => { for boolean in e.into_conjunction_iterator() { self.flatten_statement( - symbols, statements_flattened, ZirStatement::Assertion(boolean), ) } } BooleanExpression::FieldEq(box lhs, box rhs) => { - let lhs = self.flatten_field_expression(symbols, statements_flattened, lhs); - let rhs = self.flatten_field_expression(symbols, statements_flattened, rhs); + let lhs = self.flatten_field_expression(statements_flattened, lhs); + let rhs = self.flatten_field_expression(statements_flattened, rhs); self.flatten_equality(statements_flattened, lhs, rhs) } BooleanExpression::UintEq(box lhs, box rhs) => { let lhs = self - .flatten_uint_expression(symbols, statements_flattened, lhs) + .flatten_uint_expression(statements_flattened, lhs) .get_field_unchecked(); let rhs = self - .flatten_uint_expression(symbols, statements_flattened, rhs) + .flatten_uint_expression(statements_flattened, rhs) .get_field_unchecked(); self.flatten_equality(statements_flattened, lhs, rhs) } BooleanExpression::BoolEq(box lhs, box rhs) => { - let lhs = - self.flatten_boolean_expression(symbols, statements_flattened, lhs); - let rhs = - self.flatten_boolean_expression(symbols, statements_flattened, rhs); + let lhs = self.flatten_boolean_expression(statements_flattened, lhs); + let rhs = self.flatten_boolean_expression(statements_flattened, rhs); self.flatten_equality(statements_flattened, lhs, rhs) } _ => { // naive approach: flatten the boolean to a single field element and constrain it to 1 - let e = self.flatten_boolean_expression(symbols, statements_flattened, e); + let e = self.flatten_boolean_expression(statements_flattened, e); if e.is_linear() { statements_flattened.push(FlatStatement::Condition( @@ -2081,20 +2047,19 @@ impl<'ast, T: Field> Flattener<'ast, T> { // flatten the right side to p = sum(var_i.type.primitive_count) expressions // define p new variables to the right side expressions - let var_types = vars.iter().map(|v| v.get_type()).collect(); - match rhs { - ZirExpressionList::FunctionCall(key, exprs, _) => { - let rhs_flattened = self.flatten_function_call( - symbols, + ZirExpressionList::EmbedCall(embed, generics, exprs) => { + let rhs_flattened = self.flatten_embed_call( statements_flattened, - &key.id, - var_types, + embed, + generics, exprs.clone(), ); let rhs = rhs_flattened.into_iter(); + assert_eq!(vars.len(), rhs.len()); + let vars: Vec<_> = vars .into_iter() .zip(rhs) @@ -2111,15 +2076,20 @@ impl<'ast, T: Field> Flattener<'ast, T> { }) .collect(); - if ["_U32_FROM_BITS", "_U16_FROM_BITS", "_U8_FROM_BITS"].contains(&key.id) { - let bits = exprs - .into_iter() - .map(|e| { - self.flatten_expression(symbols, statements_flattened, e) - .get_field_unchecked() - }) - .collect(); - self.bits_cache.insert(vars[0].clone().into(), bits); + match embed { + FlatEmbed::U32FromBits + | FlatEmbed::U16FromBits + | FlatEmbed::U8FromBits => { + let bits = exprs + .into_iter() + .map(|e| { + self.flatten_expression(statements_flattened, e) + .get_field_unchecked() + }) + .collect(); + self.bits_cache.insert(vars[0].clone().into(), bits); + } + _ => {} } } } @@ -2130,13 +2100,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// Flattens a function /// /// # Arguments - /// * `symbols` - Available functions in in this context /// * `funct` - `ZirFunction` that will be flattened - fn flatten_function( - &mut self, - symbols: &ZirFunctionSymbols<'ast, T>, - funct: ZirFunction<'ast, T>, - ) -> FlatFunction { + fn flatten_function(&mut self, funct: ZirFunction<'ast, T>) -> FlatFunction { self.layout = HashMap::new(); self.next_var_idx = 0; @@ -2151,7 +2116,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { // flatten statements in functions and apply substitution for stat in funct.statements { - self.flatten_statement(symbols, &mut statements_flattened, stat); + self.flatten_statement(&mut statements_flattened, stat); } FlatFunction { @@ -2166,28 +2131,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { /// /// * `prog` - `ZirProgram` that will be flattened. fn flatten_program(&mut self, prog: ZirProgram<'ast, T>) -> FlatProg { - let mut prog = prog; - - let mut main_module = prog.modules.remove(&prog.main).unwrap(); - - let main_key = main_module - .functions - .keys() - .find(|k| k.id == "main") - .unwrap() - .clone(); - - let main = main_module.functions.remove(&main_key).unwrap(); - - let symbols = &main_module.functions; - - let main_flattened = match main { - ZirFunctionSymbol::Here(f) => self.flatten_function(&symbols, f), - _ => unreachable!("main should be a typed function locally"), - }; - FlatProg { - main: main_flattened, + main: self.flatten_function(prog.main), } } @@ -2251,7 +2196,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn use_variable(&mut self, variable: &Variable<'ast>) -> FlatVariable { let var = self.issue_new_variable(); - self.layout.insert(variable.id.clone(), var.clone()); + self.layout.insert(variable.id.clone(), var); var } @@ -2321,19 +2266,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { fn use_sym(&mut self) -> FlatVariable { self.issue_new_variable() } - - fn get_embed<'a>( - &mut self, - key: &'a FunctionKey<'ast>, - symbols: &'a ZirFunctionSymbols<'ast, T>, - ) -> crate::embed::FlatEmbed { - let f = symbols.get(&key).expect(&format!("{}", key.id)).clone(); - let res = match f { - ZirFunctionSymbol::Flat(flat_function) => flat_function, - _ => unreachable!("only local flat symbols can be flattened"), - }; - res - } } #[cfg(test)] @@ -2379,7 +2311,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2443,7 +2375,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2511,7 +2443,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2568,7 +2500,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2638,7 +2570,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2712,7 +2644,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2796,7 +2728,7 @@ mod tests { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let flat = flattener.flatten_function(&HashMap::new(), function); + let flat = flattener.flatten_function(function); let expected = FlatFunction { arguments: vec![], statements: vec![ @@ -2858,7 +2790,7 @@ mod tests { Variable::field_element("b"), FieldElementExpression::Pow( box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(0)), + box 0u32.into(), ) .into(), ), @@ -2890,7 +2822,7 @@ mod tests { ], }; - let flattened = flattener.flatten_function(&mut HashMap::new(), function); + let flattened = flattener.flatten_function(function); assert_eq!(flattened, expected); } @@ -2918,7 +2850,7 @@ mod tests { Variable::field_element("b"), FieldElementExpression::Pow( box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), + box 1u32.into(), ) .into(), ), @@ -2953,7 +2885,7 @@ mod tests { ], }; - let flattened = flattener.flatten_function(&mut HashMap::new(), function); + let flattened = flattener.flatten_function(function); assert_eq!(flattened, expected); } @@ -2998,7 +2930,7 @@ mod tests { Variable::field_element("b"), FieldElementExpression::Pow( box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(13)), + box 13u32.into(), ) .into(), ), @@ -3068,7 +3000,7 @@ mod tests { ], }; - let flattened = flattener.flatten_function(&mut HashMap::new(), function); + let flattened = flattener.flatten_function(function); assert_eq!(flattened, expected); } @@ -3087,33 +3019,25 @@ mod tests { let mut flattener = Flattener::new(&config); - flattener.flatten_field_expression(&HashMap::new(), &mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::new(), expression); } #[test] fn geq_leq() { let config = CompileConfig::default(); let mut flattener = Flattener::new(&config); - let expression_le = BooleanExpression::Le( + let expression_le = BooleanExpression::FieldLe( box FieldElementExpression::Number(Bn128Field::from(32)), box FieldElementExpression::Number(Bn128Field::from(4)), ); - flattener.flatten_boolean_expression( - &HashMap::new(), - &mut FlatStatements::new(), - expression_le, - ); + flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le); let mut flattener = Flattener::new(&config); - let expression_ge = BooleanExpression::Ge( + let expression_ge = BooleanExpression::FieldGe( box FieldElementExpression::Number(Bn128Field::from(32)), box FieldElementExpression::Number(Bn128Field::from(4)), ); - flattener.flatten_boolean_expression( - &HashMap::new(), - &mut FlatStatements::new(), - expression_ge, - ); + flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_ge); } #[test] @@ -3127,7 +3051,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(4)), ), - box BooleanExpression::Lt( + box BooleanExpression::FieldLt( box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(20)), ), @@ -3136,7 +3060,7 @@ mod tests { box FieldElementExpression::Number(Bn128Field::from(51)), ); - flattener.flatten_field_expression(&HashMap::new(), &mut FlatStatements::new(), expression); + flattener.flatten_field_expression(&mut FlatStatements::new(), expression); } #[test] @@ -3163,9 +3087,9 @@ mod tests { .into(), ); - flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, definition); + flattener.flatten_statement(&mut statements_flattened, definition); - flattener.flatten_statement(&HashMap::new(), &mut statements_flattened, statement); + flattener.flatten_statement(&mut statements_flattened, statement); // define b let b = FlatVariable::new(0); diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index 0f5c622d..a3775c38 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -42,7 +42,7 @@ impl fmt::Display for Error { let location = self .pos .map(|p| format!("{}", p.0)) - .unwrap_or("?".to_string()); + .unwrap_or_else(|| "?".to_string()); write!(f, "{}\n\t{}", location, self.message) } } @@ -125,15 +125,10 @@ impl<'ast> fmt::Debug for Import<'ast> { } } -pub struct Importer {} +pub struct Importer; impl Importer { - pub fn new() -> Importer { - Importer {} - } - pub fn apply_imports<'ast, T: Field, E: Into>( - &self, destination: Module<'ast>, location: PathBuf, resolver: Option<&dyn Resolver>, @@ -179,7 +174,7 @@ impl Importer { symbols.push( SymbolDeclaration { id: &alias, - symbol: Symbol::Flat(FlatEmbed::Unpack(T::get_required_bits())), + symbol: Symbol::Flat(FlatEmbed::Unpack), } .start_end(pos.0, pos.1), ); @@ -267,13 +262,15 @@ impl Importer { let alias = import.alias.unwrap_or( std::path::Path::new(import.source) .file_stem() - .ok_or(CompileErrors::from( - CompileErrorInner::ImportError(Error::new(format!( - "Could not determine alias for import {}", - import.source.display() - ))) - .in_file(&location), - ))? + .ok_or_else(|| { + CompileErrors::from( + CompileErrorInner::ImportError(Error::new(format!( + "Could not determine alias for import {}", + import.source.display() + ))) + .in_file(&location), + ) + })? .to_str() .unwrap(), ); @@ -335,7 +332,6 @@ impl Importer { Ok(Module { imports: vec![], symbols, - ..destination }) } } diff --git a/zokrates_core/src/ir/expression.rs b/zokrates_core/src/ir/expression.rs index 3a426141..57ec2753 100644 --- a/zokrates_core/src/ir/expression.rs +++ b/zokrates_core/src/ir/expression.rs @@ -2,49 +2,36 @@ use crate::flat_absy::FlatVariable; use serde::{Deserialize, Serialize}; use std::collections::btree_map::{BTreeMap, Entry}; use std::fmt; +use std::hash::Hash; use std::ops::{Add, Div, Mul, Sub}; use zokrates_field::Field; -#[derive(Debug, Clone, Serialize, Deserialize, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct QuadComb { pub left: LinComb, pub right: LinComb, } -impl PartialEq for QuadComb { - fn eq(&self, other: &Self) -> bool { - self.left.eq(&other.left) && self.right.eq(&other.right) - } -} - -impl Eq for QuadComb {} - impl QuadComb { pub fn from_linear_combinations(left: LinComb, right: LinComb) -> Self { QuadComb { left, right } } - pub fn try_linear(&self) -> Option> { - // identify (k * ~ONE) * (lincomb) and return (k * lincomb) - - match self.left.try_summand() { - Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => { - return Some(self.right.clone() * &coefficient); - } - _ => {} - } - match self.right.try_summand() { - Some((ref variable, ref coefficient)) if *variable == FlatVariable::one() => { - return Some(self.left.clone() * &coefficient); - } - _ => {} - } + pub fn try_linear(self) -> Result, Self> { + // identify `(k * ~ONE) * (lincomb)` and `(lincomb) * (k * ~ONE)` and return (k * lincomb) + // if not, error out with the input if self.left.is_zero() || self.right.is_zero() { - return Some(LinComb::zero()); + return Ok(LinComb::zero()); } - None + match self.left.try_constant() { + Ok(coefficient) => Ok(self.right * &coefficient), + Err(left) => match self.right.try_constant() { + Ok(coefficient) => Ok(left * &coefficient), + Err(right) => Err(QuadComb::from_linear_combinations(left, right)), + }, + } } } @@ -66,17 +53,9 @@ impl fmt::Display for QuadComb { } } -#[derive(Clone, Hash, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct LinComb(pub Vec<(FlatVariable, T)>); -impl PartialEq for LinComb { - fn eq(&self, other: &Self) -> bool { - self.clone().into_canonical() == other.clone().into_canonical() - } -} - -impl Eq for LinComb {} - #[derive(PartialEq, PartialOrd, Clone, Eq, Ord, Hash, Debug, Serialize, Deserialize)] pub struct CanonicalLinComb(pub BTreeMap); @@ -113,36 +92,52 @@ impl LinComb { } pub fn is_zero(&self) -> bool { - self.0.len() == 0 + self.0.is_empty() } } impl LinComb { - pub fn try_summand(&self) -> Option<(FlatVariable, T)> { + pub fn try_constant(self) -> Result { match self.0.len() { - // if the lincomb is empty, it is not reduceable to a summand - 0 => None, + // if the lincomb is empty, it is reduceable to 0 + 0 => Ok(T::zero()), _ => { // take the first variable in the lincomb let first = &self.0[0].0; - self.0 - .iter() - .map(|element| { + if first != &FlatVariable::one() { + return Err(self); + } + + // all terms must contain the same variable + if self.0.iter().all(|element| element.0 == *first) { + Ok(self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1)) + } else { + Err(self) + } + } + } + } + + pub fn try_summand(self) -> Result<(FlatVariable, T), Self> { + match self.0.len() { + // if the lincomb is empty, it is not reduceable to a summand + 0 => Err(self), + _ => { + // take the first variable in the lincomb + let first = &self.0[0].0; + + if self.0.iter().all(|element| // all terms must contain the same variable - if element.0 == *first { - // if they do, return the coefficient - Ok(&element.1) - } else { - // otherwise, stop - Err(()) - } - }) - // collect to a Result to short circuit when we hit an error - .collect::>() - // we didn't hit an error, do final processing. It's fine to clone here. - .map(|v: Vec<_>| (first.clone(), v.iter().fold(T::zero(), |acc, e| acc + *e))) - .ok() + element.0 == *first) + { + Ok(( + *first, + self.0.into_iter().fold(T::zero(), |acc, e| acc + e.1), + )) + } else { + Err(self) + } } } } @@ -207,9 +202,7 @@ impl fmt::Display for LinComb { false => write!( f, "{}", - self.clone() - .into_canonical() - .0 + self.0 .iter() .map(|(k, v)| format!("{} * {}", v.to_compact_dec_string(), k)) .collect::>() @@ -251,10 +244,14 @@ impl Mul<&T> for LinComb { type Output = LinComb; fn mul(self, scalar: &T) -> LinComb { + if scalar == &T::one() { + return self; + } + LinComb( self.0 .into_iter() - .map(|(var, coeff)| (var, coeff * scalar)) + .map(|(var, coeff)| (var, coeff * scalar.clone())) .collect(), ) } @@ -262,7 +259,8 @@ impl Mul<&T> for LinComb { impl Div<&T> for LinComb { type Output = LinComb; - + // Clippy warns about multiplication in a method named div. It's okay, here, since we multiply with the inverse. + #[allow(clippy::suspicious_arithmetic_impl)] fn div(self, scalar: &T) -> LinComb { self * &scalar.inverse_mul().unwrap() } @@ -287,7 +285,7 @@ mod tests { fn add() { let a: LinComb = FlatVariable::new(42).into(); let b: LinComb = FlatVariable::new(42).into(); - let c = a + b.clone(); + let c = a + b; let expected_vec = vec![ (FlatVariable::new(42), Bn128Field::from(1)), @@ -300,7 +298,7 @@ mod tests { fn sub() { let a: LinComb = FlatVariable::new(42).into(); let b: LinComb = FlatVariable::new(42).into(); - let c = a - b.clone(); + let c = a - b; let expected_vec = vec![ (FlatVariable::new(42), Bn128Field::from(1)), @@ -314,7 +312,7 @@ mod tests { fn display() { let a: LinComb = LinComb::from(FlatVariable::new(42)) + LinComb::summand(3, FlatVariable::new(21)); - assert_eq!(&a.to_string(), "3 * _21 + 1 * _42"); + assert_eq!(&a.to_string(), "1 * _42 + 3 * _21"); let zero: LinComb = LinComb::zero(); assert_eq!(&zero.to_string(), "0"); } @@ -350,7 +348,7 @@ mod tests { + LinComb::summand(4, FlatVariable::new(33)), right: LinComb::summand(1, FlatVariable::new(21)), }; - assert_eq!(&a.to_string(), "(4 * _33 + 3 * _42) * (1 * _21)"); + assert_eq!(&a.to_string(), "(3 * _42 + 4 * _33) * (1 * _21)"); let a: QuadComb = QuadComb { left: LinComb::zero(), right: LinComb::summand(1, FlatVariable::new(21)), @@ -371,7 +369,7 @@ mod tests { ]); assert_eq!( summand.try_summand(), - Some((FlatVariable::new(42), Bn128Field::from(6))) + Ok((FlatVariable::new(42), Bn128Field::from(6))) ); let not_summand = LinComb(vec![ @@ -379,10 +377,10 @@ mod tests { (FlatVariable::new(42), Bn128Field::from(2)), (FlatVariable::new(42), Bn128Field::from(3)), ]); - assert_eq!(not_summand.try_summand(), None); + assert!(not_summand.try_summand().is_err()); let empty: LinComb = LinComb(vec![]); - assert_eq!(empty.try_summand(), None); + assert!(empty.try_summand().is_err()); } } } diff --git a/zokrates_core/src/ir/from_flat.rs b/zokrates_core/src/ir/from_flat.rs index 90b202da..0f6b2d1c 100644 --- a/zokrates_core/src/ir/from_flat.rs +++ b/zokrates_core/src/ir/from_flat.rs @@ -125,7 +125,7 @@ impl From> for Directive { inputs: ds .inputs .into_iter() - .map(|i| QuadComb::from_flat_expression(i)) + .map(QuadComb::from_flat_expression) .collect(), solver: ds.solver, outputs: ds.outputs, diff --git a/zokrates_core/src/ir/interpreter.rs b/zokrates_core/src/ir/interpreter.rs index 183c30d6..69ffacf1 100644 --- a/zokrates_core/src/ir/interpreter.rs +++ b/zokrates_core/src/ir/interpreter.rs @@ -1,10 +1,13 @@ use crate::flat_absy::flat_variable::FlatVariable; use crate::ir::Directive; use crate::ir::{LinComb, Prog, QuadComb, Statement, Witness}; -use crate::solvers::{Executable, Solver}; +use crate::solvers::Solver; +use pairing_ce::bn256::Bn256; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::fmt; +#[cfg(feature = "bellman")] +use zokrates_embed::generate_sha256_round_witness; use zokrates_field::Field; pub type ExecutionResult = Result, Error>; @@ -34,13 +37,13 @@ impl Interpreter { } impl Interpreter { - pub fn execute(&self, program: &Prog, inputs: &Vec) -> ExecutionResult { + pub fn execute(&self, program: &Prog, inputs: &[T]) -> ExecutionResult { let main = &program.main; self.check_inputs(&program, &inputs)?; let mut witness = BTreeMap::new(); witness.insert(FlatVariable::one(), T::one()); for (arg, value) in main.arguments.iter().zip(inputs.iter()) { - witness.insert(arg.clone(), value.clone().into()); + witness.insert(*arg, value.clone()); } for statement in main.statements.iter() { @@ -48,7 +51,7 @@ impl Interpreter { Statement::Constraint(quad, lin) => match lin.is_assignee(&witness) { true => { let val = quad.evaluate(&witness).unwrap(); - witness.insert(lin.0.iter().next().unwrap().0.clone(), val); + witness.insert(lin.0.get(0).unwrap().0, val); } false => { let lhs_value = quad.evaluate(&witness).unwrap(); @@ -76,10 +79,10 @@ impl Interpreter { .iter() .map(|i| i.evaluate(&witness).unwrap()) .collect(); - match d.solver.execute(&inputs) { + match self.execute_solver(&d.solver, &inputs) { Ok(res) => { for (i, o) in d.outputs.iter().enumerate() { - witness.insert(o.clone(), res[i].clone()); + witness.insert(*o, res[i].clone()); } continue; } @@ -107,12 +110,12 @@ impl Interpreter { value.to_biguint() }; - let mut num = input.clone(); + let mut num = input; let mut res = vec![]; let bits = T::get_required_bits(); for i in (0..bits).rev() { if T::from(2).to_biguint().pow(i as usize) <= num { - num = num - T::from(2).to_biguint().pow(i as usize); + num -= T::from(2).to_biguint().pow(i as usize); res.push(T::one()); } else { res.push(T::zero()); @@ -120,11 +123,11 @@ impl Interpreter { } assert_eq!(num, T::zero().to_biguint()); for (i, o) in d.outputs.iter().enumerate() { - witness.insert(o.clone(), res[i].clone()); + witness.insert(*o, res[i].clone()); } } - fn check_inputs(&self, program: &Prog, inputs: &Vec) -> Result<(), Error> { + fn check_inputs(&self, program: &Prog, inputs: &[U]) -> Result<(), Error> { if program.main.arguments.len() == inputs.len() { Ok(()) } else { @@ -134,26 +137,136 @@ impl Interpreter { }) } } + + pub fn execute_solver( + &self, + solver: &Solver, + inputs: &[T], + ) -> Result, String> { + let (expected_input_count, expected_output_count) = solver.get_signature(); + assert!(inputs.len() == expected_input_count); + + let res = match solver { + Solver::ConditionEq => match inputs[0].is_zero() { + true => vec![T::zero(), T::one()], + false => vec![ + T::one(), + T::one().checked_div(&inputs[0]).unwrap_or_else(T::one), + ], + }, + Solver::Bits(bit_width) => { + let mut num = inputs[0].clone(); + let mut res = vec![]; + + for i in (0..*bit_width).rev() { + if T::from(2).pow(i) <= num { + num = num - T::from(2).pow(i); + res.push(T::one()); + } else { + res.push(T::zero()); + } + } + res + } + Solver::Xor => { + let x = inputs[0].clone(); + let y = inputs[1].clone(); + + vec![x.clone() + y.clone() - T::from(2) * x * y] + } + Solver::Or => { + let x = inputs[0].clone(); + let y = inputs[1].clone(); + + vec![x.clone() + y.clone() - x * y] + } + // res = b * c - (2b * c - b - c) * (a) + Solver::ShaAndXorAndXorAnd => { + let a = inputs[0].clone(); + let b = inputs[1].clone(); + let c = inputs[2].clone(); + vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] + } + // res = a(b - c) + c + Solver::ShaCh => { + let a = inputs[0].clone(); + let b = inputs[1].clone(); + let c = inputs[2].clone(); + vec![a * (b - c.clone()) + c] + } + + Solver::Div => vec![inputs[0] + .clone() + .checked_div(&inputs[1]) + .unwrap_or_else(T::one)], + Solver::EuclideanDiv => { + use num::CheckedDiv; + + let n = inputs[0].clone().to_biguint(); + let d = inputs[1].clone().to_biguint(); + + let q = n.checked_div(&d).unwrap_or_else(|| 0u32.into()); + let r = n - d * &q; + vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()] + } + #[cfg(feature = "bellman")] + Solver::Sha256Round => { + use zokrates_field::Bn128Field; + assert_eq!(T::id(), Bn128Field::id()); + let i = &inputs[0..512]; + let h = &inputs[512..]; + let to_fr = |x: &T| { + use pairing_ce::ff::{PrimeField, ScalarEngine}; + let s = x.to_dec_string(); + ::Fr::from_str(&s).unwrap() + }; + let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect(); + let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect(); + assert_eq!(h.len(), 256); + generate_sha256_round_witness::(&i, &h) + .into_iter() + .map(|x| { + use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr}; + let mut res: Vec = vec![]; + x.into_repr().write_le(&mut res).unwrap(); + T::from_byte_vector(res) + }) + .collect() + } + }; + + assert_eq!(res.len(), expected_output_count); + + Ok(res) + } } +#[derive(Debug)] +pub struct EvaluationError; + impl LinComb { - fn evaluate(&self, witness: &BTreeMap) -> Result { + fn evaluate(&self, witness: &BTreeMap) -> Result { self.0 .iter() - .map(|(var, mult)| witness.get(var).map(|v| v.clone() * mult).ok_or(())) // get each term + .map(|(var, mult)| { + witness + .get(var) + .map(|v| v.clone() * mult) + .ok_or(EvaluationError) + }) // get each term .collect::, _>>() // fail if any term isn't found .map(|v| v.iter().fold(T::from(0), |acc, t| acc + t)) // return the sum } fn is_assignee(&self, witness: &BTreeMap) -> bool { self.0.iter().count() == 1 - && self.0.iter().next().unwrap().1 == T::from(1) - && !witness.contains_key(&self.0.iter().next().unwrap().0) + && self.0.get(0).unwrap().1 == T::from(1) + && !witness.contains_key(&self.0.get(0).unwrap().0) } } impl QuadComb { - pub fn evaluate(&self, witness: &BTreeMap) -> Result { + pub fn evaluate(&self, witness: &BTreeMap) -> Result { let left = self.left.evaluate(&witness)?; let right = self.right.evaluate(&witness)?; Ok(left * right) @@ -192,3 +305,83 @@ impl fmt::Debug for Error { write!(f, "{}", self) } } + +#[cfg(test)] +mod tests { + use super::*; + use zokrates_field::Bn128Field; + + mod eq_condition { + + // Wanted: (Y = (X != 0) ? 1 : 0) + // # Y = if X == 0 then 0 else 1 fi + // # M = if X == 0 then 1 else 1/X fi + + use super::*; + + #[test] + fn execute() { + let cond_eq = Solver::ConditionEq; + let inputs = vec![0]; + let interpreter = Interpreter::default(); + let r = interpreter + .execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); + let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); + assert_eq!(r, &res[..]); + } + + #[test] + fn execute_non_eq() { + let cond_eq = Solver::ConditionEq; + let inputs = vec![1]; + let interpreter = Interpreter::default(); + let r = interpreter + .execute_solver( + &cond_eq, + &inputs + .iter() + .map(|&i| Bn128Field::from(i)) + .collect::>(), + ) + .unwrap(); + let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); + assert_eq!(r, &res[..]); + } + } + + #[test] + fn bits_of_one() { + let inputs = vec![Bn128Field::from(1)]; + let interpreter = Interpreter::default(); + let res = interpreter + .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); + assert_eq!(res[253], Bn128Field::from(1)); + for i in 0..253 { + assert_eq!(res[i], Bn128Field::from(0)); + } + } + + #[test] + fn bits_of_42() { + let inputs = vec![Bn128Field::from(42)]; + let interpreter = Interpreter::default(); + let res = interpreter + .execute_solver(&Solver::Bits(Bn128Field::get_required_bits()), &inputs) + .unwrap(); + assert_eq!(res[253], Bn128Field::from(0)); + assert_eq!(res[252], Bn128Field::from(1)); + assert_eq!(res[251], Bn128Field::from(0)); + assert_eq!(res[250], Bn128Field::from(1)); + assert_eq!(res[249], Bn128Field::from(0)); + assert_eq!(res[248], Bn128Field::from(1)); + assert_eq!(res[247], Bn128Field::from(0)); + } +} diff --git a/zokrates_core/src/ir/mod.rs b/zokrates_core/src/ir/mod.rs index c120806b..dfc97f76 100644 --- a/zokrates_core/src/ir/mod.rs +++ b/zokrates_core/src/ir/mod.rs @@ -3,6 +3,7 @@ use crate::flat_absy::FlatVariable; use crate::solvers::Solver; use serde::{Deserialize, Serialize}; use std::fmt; +use std::hash::Hash; use zokrates_field::Field; mod expression; @@ -19,26 +20,12 @@ pub use self::serialize::ProgEnum; pub use self::interpreter::{Error, ExecutionResult, Interpreter}; pub use self::witness::Witness; -#[derive(Debug, Serialize, Deserialize, Clone, Hash)] +#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] pub enum Statement { Constraint(QuadComb, LinComb), Directive(Directive), } -impl PartialEq for Statement { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Statement::Constraint(l1, r1), Statement::Constraint(l2, r2)) => { - l1.eq(l2) && r1.eq(r2) - } - (Statement::Directive(d1), Statement::Directive(d2)) => d1.eq(d2), - _ => false, - } - } -} - -impl Eq for Statement {} - impl Statement { pub fn definition>>(v: FlatVariable, e: U) -> Self { Statement::Constraint(e.into(), v.into()) @@ -49,23 +36,13 @@ impl Statement { } } -#[derive(Clone, Debug, Serialize, Deserialize, Hash)] +#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)] pub struct Directive { pub inputs: Vec>, pub outputs: Vec, pub solver: Solver, } -impl PartialEq for Directive { - fn eq(&self, other: &Self) -> bool { - self.inputs.eq(&other.inputs) - && self.outputs.eq(&other.outputs) - && self.solver.eq(&other.solver) - } -} - -impl Eq for Directive {} - impl fmt::Display for Directive { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -95,7 +72,7 @@ impl fmt::Display for Statement { } } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)] pub struct Function { pub id: String, pub statements: Vec>, @@ -103,15 +80,6 @@ pub struct Function { pub returns: Vec, } -impl PartialEq for Function { - fn eq(&self, other: &Self) -> bool { - self.id.eq(&other.id) - && self.statements.eq(&other.statements) - && self.arguments.eq(&other.arguments) - && self.returns.eq(&other.returns) - } -} - impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( @@ -138,27 +106,18 @@ impl fmt::Display for Function { } } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)] pub struct Prog { pub main: Function, pub private: Vec, } -impl PartialEq for Prog { - fn eq(&self, other: &Self) -> bool { - self.main.eq(&other.main) && self.private.eq(&other.private) - } -} - impl Prog { pub fn constraint_count(&self) -> usize { self.main .statements .iter() - .filter(|s| match s { - Statement::Constraint(..) => true, - _ => false, - }) + .filter(|s| matches!(s, Statement::Constraint(..))) .count() } diff --git a/zokrates_core/src/ir/serialize.rs b/zokrates_core/src/ir/serialize.rs index 8b352f32..d97220a4 100644 --- a/zokrates_core/src/ir/serialize.rs +++ b/zokrates_core/src/ir/serialize.rs @@ -16,9 +16,9 @@ pub enum ProgEnum { impl Prog { pub fn serialize(&self, mut w: W) { - w.write(ZOKRATES_MAGIC).unwrap(); - w.write(ZOKRATES_VERSION_1).unwrap(); - w.write(&T::id()).unwrap(); + w.write_all(ZOKRATES_MAGIC).unwrap(); + w.write_all(ZOKRATES_VERSION_1).unwrap(); + w.write_all(&T::id()).unwrap(); serialize_into(&mut w, self, Infinite).unwrap(); } diff --git a/zokrates_core/src/macros.rs b/zokrates_core/src/macros.rs index ff36c1e7..a6d5b969 100644 --- a/zokrates_core/src/macros.rs +++ b/zokrates_core/src/macros.rs @@ -19,7 +19,7 @@ impl fmt::Display for Error { } } -pub fn process_macros<'ast, T: Field>(file: File<'ast>) -> Result, Error> { +pub fn process_macros(file: File) -> Result { match &file.pragma { Some(pragma) => { if T::name() != pragma.curve.name { diff --git a/zokrates_core/src/optimizer/canonicalizer.rs b/zokrates_core/src/optimizer/canonicalizer.rs new file mode 100644 index 00000000..69981b05 --- /dev/null +++ b/zokrates_core/src/optimizer/canonicalizer.rs @@ -0,0 +1,10 @@ +use crate::ir::{folder::Folder, LinComb}; +use zokrates_field::Field; + +pub struct Canonicalizer; + +impl Folder for Canonicalizer { + fn fold_linear_combination(&mut self, l: LinComb) -> LinComb { + l.into_canonical().into() + } +} diff --git a/zokrates_core/src/optimizer/directive.rs b/zokrates_core/src/optimizer/directive.rs index b2972046..d84a72b3 100644 --- a/zokrates_core/src/optimizer/directive.rs +++ b/zokrates_core/src/optimizer/directive.rs @@ -12,10 +12,10 @@ use crate::flat_absy::flat_variable::FlatVariable; use crate::ir::folder::*; use crate::ir::*; +use crate::optimizer::canonicalizer::Canonicalizer; use crate::solvers::Solver; use std::collections::hash_map::{Entry, HashMap}; use zokrates_field::Field; - #[derive(Debug)] pub struct DirectiveOptimizer { calls: HashMap<(Solver, Vec>), Vec>, @@ -37,6 +37,23 @@ impl DirectiveOptimizer { } impl Folder for DirectiveOptimizer { + fn fold_function(&mut self, f: Function) -> Function { + // in order to correcty identify duplicates, we need to first canonicalize the statements + + let mut canonicalizer = Canonicalizer; + + let f = Function { + statements: f + .statements + .into_iter() + .flat_map(|s| canonicalizer.fold_statement(s)) + .collect(), + ..f + }; + + fold_function(self, f) + } + fn fold_statement(&mut self, s: Statement) -> Vec> { match s { Statement::Directive(d) => { @@ -49,7 +66,7 @@ impl Folder for DirectiveOptimizer { } Entry::Occupied(e) => { self.substitution - .extend(d.outputs.into_iter().zip(e.get().into_iter().cloned())); + .extend(d.outputs.into_iter().zip(e.get().iter().cloned())); vec![] } } diff --git a/zokrates_core/src/optimizer/duplicate.rs b/zokrates_core/src/optimizer/duplicate.rs index f72b3a9f..f5a6aa96 100644 --- a/zokrates_core/src/optimizer/duplicate.rs +++ b/zokrates_core/src/optimizer/duplicate.rs @@ -1,7 +1,8 @@ //! Module containing the `DuplicateOptimizer` to remove duplicate constraints -use crate::ir::folder::Folder; +use crate::ir::folder::*; use crate::ir::*; +use crate::optimizer::canonicalizer::Canonicalizer; use std::collections::{hash_map::DefaultHasher, HashSet}; use zokrates_field::Field; @@ -33,6 +34,22 @@ impl DuplicateOptimizer { } impl Folder for DuplicateOptimizer { + fn fold_function(&mut self, f: Function) -> Function { + // in order to correcty identify duplicates, we need to first canonicalize the statements + let mut canonicalizer = Canonicalizer; + + let f = Function { + statements: f + .statements + .into_iter() + .flat_map(|s| canonicalizer.fold_statement(s)) + .collect(), + ..f + }; + + fold_function(self, f) + } + fn fold_statement(&mut self, s: Statement) -> Vec> { let hashed = hash(&s); let result = match self.seen.get(&hashed) { @@ -120,7 +137,7 @@ mod tests { main: Function { id: "main".to_string(), statements: vec![ - constraint.clone(), + constraint, Statement::Constraint( QuadComb::from_linear_combinations( LinComb::summand(3, FlatVariable::new(42)), diff --git a/zokrates_core/src/optimizer/mod.rs b/zokrates_core/src/optimizer/mod.rs index fb0e34ab..6fd1088a 100644 --- a/zokrates_core/src/optimizer/mod.rs +++ b/zokrates_core/src/optimizer/mod.rs @@ -4,6 +4,7 @@ //! @author Thibaut Schaeffer //! @date 2018 +mod canonicalizer; mod directive; mod duplicate; mod redefinition; @@ -26,7 +27,6 @@ impl Prog { // // deduplicate directives which take the same input let r = DirectiveOptimizer::optimize(r); // remove duplicate constraints - let r = DuplicateOptimizer::optimize(r); - r + DuplicateOptimizer::optimize(r) } } diff --git a/zokrates_core/src/optimizer/redefinition.rs b/zokrates_core/src/optimizer/redefinition.rs index c1994eec..467158de 100644 --- a/zokrates_core/src/optimizer/redefinition.rs +++ b/zokrates_core/src/optimizer/redefinition.rs @@ -40,7 +40,6 @@ use crate::flat_absy::flat_variable::FlatVariable; use crate::ir::folder::{fold_function, Folder}; use crate::ir::LinComb; use crate::ir::*; -use crate::solvers::Executable; use std::collections::{HashMap, HashSet}; use zokrates_field::Field; @@ -53,7 +52,7 @@ pub struct RedefinitionOptimizer { } impl RedefinitionOptimizer { - fn new() -> RedefinitionOptimizer { + fn new() -> Self { RedefinitionOptimizer { substitution: HashMap::new(), ignore: HashSet::new(), @@ -72,84 +71,77 @@ impl Folder for RedefinitionOptimizer { let quad = self.fold_quadratic_combination(quad); let lin = self.fold_linear_combination(lin); - let (keep_constraint, to_insert, to_ignore) = match lin.try_summand() { - // if the right side is a single variable - Some((variable, coefficient)) => { - match self.ignore.contains(&variable) { - // if the variable isn't tagged as ignored - false => match self.substitution.get(&variable) { - // if the variable is already defined - Some(_) => (true, None, None), - // if the variable is not defined yet - None => match quad.try_linear() { - // if the left side is linear - Some(l) => (false, Some((variable, l / &coefficient)), None), - // if the left side isn't linear - None => (true, None, Some(variable)), - }, - }, - true => (true, None, None), - } - } - None => (true, None, None), + if lin.is_zero() { + return vec![Statement::Constraint(quad, lin)]; + } + + let (constraint, to_insert, to_ignore) = match self.ignore.contains(&lin.0[0].0) + || self.substitution.contains_key(&lin.0[0].0) + { + true => (Some(Statement::Constraint(quad, lin)), None, None), + false => match lin.try_summand() { + // if the right side is a single variable + Ok((variable, coefficient)) => match quad.try_linear() { + // if the left side is linear + Ok(l) => (None, Some((variable, l / &coefficient)), None), + // if the left side isn't linear + Err(quad) => ( + Some(Statement::Constraint( + quad, + LinComb::summand(coefficient, variable), + )), + None, + Some(variable), + ), + }, + Err(l) => (Some(Statement::Constraint(quad, l)), None, None), + }, }; // insert into the ignored set - match to_ignore { - Some(v) => { - self.ignore.insert(v); - } - None => {} + if let Some(v) = to_ignore { + self.ignore.insert(v); } // insert into the substitution map - match to_insert { - Some((k, v)) => { - self.substitution.insert(k, v.into_canonical()); - } - None => {} + if let Some((k, v)) = to_insert { + self.substitution.insert(k, v.into_canonical()); }; // decide whether the constraint should be kept - match keep_constraint { - false => vec![], - true => vec![Statement::Constraint(quad, lin)], + match constraint { + Some(c) => vec![c], + _ => vec![], } } Statement::Directive(d) => { let d = self.fold_directive(d); // check if the inputs are constants, ie reduce to the form `coeff * ~one` - let inputs = d + let inputs: Vec<_> = d .inputs .into_iter() // we need to reduce to the canonical form to interpret `a + 1 - a` as `1` .map(|i| i.reduce()) - .map(|q| match q.try_linear() { - Some(l) => match l.0.len() { - // 0 is constant and can be represented by an empty lincomb - 0 => Ok(T::from(0)), - _ => l - // try to match to a single summand `coeff * v` - .try_summand() - .map(|(variable, coefficient)| match variable { - // v must be ~one - v if v == FlatVariable::one() => Ok(coefficient), - _ => Err(LinComb::summand(coefficient, variable).into()), - }) - .unwrap_or(Err(l.into())), - }, - None => Err(q), + .map(|q| { + match q + .try_linear() + .map(|l| l.try_constant().map_err(|l| l.into())) + { + Ok(r) => r, + Err(e) => Err(e), + } }) .collect::>>>(); - match inputs.iter().all(|r| r.is_ok()) { + match inputs.iter().all(|i| i.is_ok()) { true => { // unwrap inputs to their constant value - let inputs = inputs.into_iter().map(|i| i.unwrap()).collect(); + let inputs: Vec<_> = inputs.into_iter().map(|i| i.unwrap()).collect(); // run the solver - let outputs = d.solver.execute(&inputs).unwrap(); - + let outputs = Interpreter::default() + .execute_solver(&d.solver, &inputs) + .unwrap(); assert_eq!(outputs.len(), d.outputs.len()); // insert the results in the substitution @@ -160,8 +152,8 @@ impl Folder for RedefinitionOptimizer { vec![] } false => { - // reconstruct the input expressions - let inputs = inputs + //reconstruct the input expressions + let inputs: Vec<_> = inputs .into_iter() .map(|i| { i.map(|v| LinComb::summand(v, FlatVariable::one()).into()) @@ -183,8 +175,7 @@ impl Folder for RedefinitionOptimizer { match lc .0 .iter() - .find(|(variable, _)| self.substitution.get(&variable).is_some()) - .is_some() + .any(|(variable, _)| self.substitution.get(&variable).is_some()) { true => // for each summand, check if it is equal to a linear term in our substitution, otherwise keep it as is @@ -194,7 +185,7 @@ impl Folder for RedefinitionOptimizer { self.substitution .get(&variable) .map(|l| LinComb::from(l.clone()) * &coefficient) - .unwrap_or(LinComb::summand(coefficient, variable)) + .unwrap_or_else(|| LinComb::summand(coefficient, variable)) }) .fold(LinComb::zero(), |acc, x| acc + x) } @@ -209,9 +200,6 @@ impl Folder for RedefinitionOptimizer { } fn fold_function(&mut self, fun: Function) -> Function { - self.substitution.drain(); - self.ignore.drain(); - // to prevent the optimiser from replacing outputs, add them to the ignored set self.ignore.extend(fun.returns.iter().cloned()); @@ -378,7 +366,7 @@ mod tests { // -> // def main(x, y) -> (1): - // 6*x + 6*y == 6*x + 6*y // will be eliminated as a tautology + // 1*x + 1*y + 2*x + 2*y + 3*x + 3*y == 6*x + 6*y // will be eliminated as a tautology // return 6*x + 6*y let x = FlatVariable::new(0); @@ -412,7 +400,15 @@ mod tests { LinComb::summand(6, x) + LinComb::summand(6, y), LinComb::summand(6, x) + LinComb::summand(6, y), ), - Statement::definition(r, LinComb::summand(6, x) + LinComb::summand(6, y)), + Statement::definition( + r, + LinComb::summand(1, x) + + LinComb::summand(1, y) + + LinComb::summand(2, x) + + LinComb::summand(2, y) + + LinComb::summand(3, x) + + LinComb::summand(3, y), + ), ], returns: vec![r], }; diff --git a/zokrates_core/src/optimizer/tautology.rs b/zokrates_core/src/optimizer/tautology.rs index d6558412..c8a4df31 100644 --- a/zokrates_core/src/optimizer/tautology.rs +++ b/zokrates_core/src/optimizer/tautology.rs @@ -25,17 +25,16 @@ impl TautologyOptimizer { impl Folder for TautologyOptimizer { fn fold_statement(&mut self, s: Statement) -> Vec> { match s { - Statement::Constraint(quad, lin) => { - match quad.try_linear() { - Some(l) => { - if l == lin { - return vec![]; - } + Statement::Constraint(quad, lin) => match quad.try_linear() { + Ok(l) => { + if l == lin { + vec![] + } else { + vec![Statement::Constraint(l.into(), lin)] } - None => {} } - vec![Statement::Constraint(quad, lin)] - } + Err(quad) => vec![Statement::Constraint(quad, lin)], + }, _ => fold_statement(self, s), } } diff --git a/zokrates_core/src/proof_system/ark/gm17.rs b/zokrates_core/src/proof_system/ark/gm17.rs index 19888fdf..7c2d5026 100644 --- a/zokrates_core/src/proof_system/ark/gm17.rs +++ b/zokrates_core/src/proof_system/ark/gm17.rs @@ -78,7 +78,7 @@ impl Backend for Ark { query: vk .query .into_iter() - .map(|g1| serialization::to_g1::(g1)) + .map(serialization::to_g1::) .collect(), }; @@ -172,7 +172,7 @@ impl Backend for Ark { query: vk .query .into_iter() - .map(|g1| serialization::to_g1::(g1)) + .map(serialization::to_g1::) .collect(), }; diff --git a/zokrates_core/src/proof_system/ark/mod.rs b/zokrates_core/src/proof_system/ark/mod.rs index a21fc448..1e413dbb 100644 --- a/zokrates_core/src/proof_system/ark/mod.rs +++ b/zokrates_core/src/proof_system/ark/mod.rs @@ -49,42 +49,33 @@ fn ark_combination( cs: &mut ConstraintSystem<<::ArkEngine as PairingEngine>::Fr>, symbols: &mut BTreeMap, witness: &mut Witness, -) -> Result< - LinearCombination<<::ArkEngine as PairingEngine>::Fr>, - SynthesisError, -> { - let lc = - l.0.into_iter() - .map(|(k, v)| { - ( - v.into_ark(), - symbols - .entry(k) - .or_insert_with(|| { - match k.is_output() { - true => cs.new_input_variable(|| { - Ok(witness - .0 - .remove(&k) - .ok_or(SynthesisError::AssignmentMissing)? - .into_ark()) - }), - false => cs.new_witness_variable(|| { - Ok(witness - .0 - .remove(&k) - .ok_or(SynthesisError::AssignmentMissing)? - .into_ark()) - }), - } - .unwrap() - }) - .clone(), - ) - }) - .fold(LinearCombination::zero(), |acc, e| acc + e); - - Ok(lc) +) -> LinearCombination<<::ArkEngine as PairingEngine>::Fr> { + l.0.into_iter() + .map(|(k, v)| { + ( + v.into_ark(), + *symbols.entry(k).or_insert_with(|| { + match k.is_output() { + true => cs.new_input_variable(|| { + Ok(witness + .0 + .remove(&k) + .ok_or(SynthesisError::AssignmentMissing)? + .into_ark()) + }), + false => cs.new_witness_variable(|| { + Ok(witness + .0 + .remove(&k) + .ok_or(SynthesisError::AssignmentMissing)? + .into_ark()) + }), + } + .unwrap() + }), + ) + }) + .fold(LinearCombination::zero(), |acc, e| acc + e) } impl Prog { @@ -96,7 +87,7 @@ impl Prog { // mapping from IR variables let mut symbols = BTreeMap::new(); - let mut witness = witness.unwrap_or(Witness::empty()); + let mut witness = witness.unwrap_or_else(Witness::empty); assert!(symbols.insert(FlatVariable::one(), ConstraintSystem::<<::ArkEngine as PairingEngine>::Fr>::one()).is_none()); @@ -127,37 +118,34 @@ impl Prog { }), } .unwrap(); - (var.clone(), wire) + (*var, wire) }), ); let main = self.main; for statement in main.statements { - match statement { - Statement::Constraint(quad, lin) => { - let a = ark_combination( - quad.left.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - )?; - let b = ark_combination( - quad.right.clone().into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - )?; - let c = ark_combination( - lin.into_canonical(), - &mut cs, - &mut symbols, - &mut witness, - )?; + if let Statement::Constraint(quad, lin) = statement { + let a = ark_combination( + quad.left.clone().into_canonical(), + &mut cs, + &mut symbols, + &mut witness, + ); + let b = ark_combination( + quad.right.clone().into_canonical(), + &mut cs, + &mut symbols, + &mut witness, + ); + let c = ark_combination( + lin.into_canonical(), + &mut cs, + &mut symbols, + &mut witness, + ); - cs.enforce_constraint(a, b, c)?; - } - _ => {} + cs.enforce_constraint(a, b, c)?; } } diff --git a/zokrates_core/src/proof_system/bellman/groth16.rs b/zokrates_core/src/proof_system/bellman/groth16.rs index b8aeb830..9b1e1130 100644 --- a/zokrates_core/src/proof_system/bellman/groth16.rs +++ b/zokrates_core/src/proof_system/bellman/groth16.rs @@ -81,7 +81,7 @@ impl Backend for Bellman { ic: vk .gamma_abc .into_iter() - .map(|g1| serialization::to_g1::(g1)) + .map(serialization::to_g1::) .collect(), }; diff --git a/zokrates_core/src/proof_system/bellman/mod.rs b/zokrates_core/src/proof_system/bellman/mod.rs index 0368fad6..c1242fee 100644 --- a/zokrates_core/src/proof_system/bellman/mod.rs +++ b/zokrates_core/src/proof_system/bellman/mod.rs @@ -51,34 +51,31 @@ fn bellman_combination cs.alloc_input( - || format!("{}", k), - || { - Ok(witness - .0 - .remove(&k) - .ok_or(SynthesisError::AssignmentMissing)? - .into_bellman()) - }, - ), - false => cs.alloc( - || format!("{}", k), - || { - Ok(witness - .0 - .remove(&k) - .ok_or(SynthesisError::AssignmentMissing)? - .into_bellman()) - }, - ), - } - .unwrap() - }) - .clone(), + *symbols.entry(k).or_insert_with(|| { + match k.is_output() { + true => cs.alloc_input( + || format!("{}", k), + || { + Ok(witness + .0 + .remove(&k) + .ok_or(SynthesisError::AssignmentMissing)? + .into_bellman()) + }, + ), + false => cs.alloc( + || format!("{}", k), + || { + Ok(witness + .0 + .remove(&k) + .ok_or(SynthesisError::AssignmentMissing)? + .into_bellman()) + }, + ), + } + .unwrap() + }), ) }) .fold(LinearCombination::zero(), |acc, e| acc + e) @@ -93,7 +90,7 @@ impl Prog { // mapping from IR variables let mut symbols = BTreeMap::new(); - let mut witness = witness.unwrap_or(Witness::empty()); + let mut witness = witness.unwrap_or_else(Witness::empty); assert!(symbols.insert(FlatVariable::one(), CS::one()).is_none()); @@ -127,33 +124,29 @@ impl Prog { ), } .unwrap(); - (var.clone(), wire) + (*var, wire) }), ); let main = self.main; for statement in main.statements { - match statement { - Statement::Constraint(quad, lin) => { - let a = &bellman_combination( - quad.left.into_canonical(), - cs, - &mut symbols, - &mut witness, - ); - let b = &bellman_combination( - quad.right.into_canonical(), - cs, - &mut symbols, - &mut witness, - ); - let c = - &bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness); + if let Statement::Constraint(quad, lin) = statement { + let a = &bellman_combination( + quad.left.into_canonical(), + cs, + &mut symbols, + &mut witness, + ); + let b = &bellman_combination( + quad.right.into_canonical(), + cs, + &mut symbols, + &mut witness, + ); + let c = &bellman_combination(lin.into_canonical(), cs, &mut symbols, &mut witness); - cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c); - } - _ => {} + cs.enforce(|| "Constraint", |lc| lc + a, |lc| lc + b, |lc| lc + c); } } diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 404817c5..1ae3d683 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -7,7 +7,8 @@ use crate::absy::Identifier; use crate::absy::*; use crate::typed_absy::*; -use crate::typed_absy::{Parameter, Variable}; +use crate::typed_absy::{DeclarationParameter, DeclarationVariable, Variable}; +use num_bigint::BigUint; use std::collections::{hash_map::Entry, BTreeSet, HashMap, HashSet}; use std::fmt; use std::path::PathBuf; @@ -16,11 +17,15 @@ use zokrates_field::Field; use crate::parser::Position; use crate::absy::types::{UnresolvedSignature, UnresolvedType, UserTypeId}; -use crate::typed_absy::types::{FunctionKey, Signature, StructLocation, Type}; -use crate::typed_absy::types::{ArrayType, StructMember}; +use crate::typed_absy::types::{ + ArrayType, Constant, DeclarationArrayType, DeclarationFunctionKey, DeclarationSignature, + DeclarationStructMember, DeclarationStructType, DeclarationType, StructLocation, +}; use std::hash::{Hash, Hasher}; +use std::convert::TryInto; + #[derive(PartialEq, Debug)] pub struct ErrorInner { pos: Option<(Position, Position)>, @@ -42,33 +47,33 @@ impl ErrorInner { } } -type TypeMap = HashMap>; +type TypeMap<'ast> = HashMap>>; /// The global state of the program during semantic checks #[derive(Debug)] -struct State<'ast, T: Field> { +struct State<'ast, T> { /// The modules yet to be checked, which we consume as we explore the dependency tree modules: Modules<'ast>, /// The already checked modules, which we're returning at the end typed_modules: TypedModules<'ast, T>, /// The user-defined types, which we keep track at this phase only. In later phases, we rely only on basic types and combinations thereof - types: TypeMap, + types: TypeMap<'ast>, } /// A symbol for a given name: either a type or a group of functions. Not both! #[derive(PartialEq, Hash, Eq, Debug)] -enum SymbolType { +enum SymbolType<'ast> { Type, - Functions(BTreeSet), + Functions(BTreeSet>), } /// A data structure to keep track of all symbols in a module #[derive(Default)] -struct SymbolUnifier { - symbols: HashMap, +struct SymbolUnifier<'ast> { + symbols: HashMap>, } -impl SymbolUnifier { +impl<'ast> SymbolUnifier<'ast> { fn insert_type>(&mut self, id: S) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { @@ -82,7 +87,11 @@ impl SymbolUnifier { } } - fn insert_function>(&mut self, id: S, signature: Signature) -> bool { + fn insert_function>( + &mut self, + id: S, + signature: DeclarationSignature<'ast>, + ) -> bool { let s_type = self.symbols.entry(id.into()); match s_type { // if anything is already called `id`, it depends what it is @@ -90,7 +99,7 @@ impl SymbolUnifier { match o.get_mut() { // if it's a Type, then we can't introduce a function SymbolType::Type => false, - // if it's a Function, we can introduce a new function only if it has a different signature + // if it's a Function, we can introduce it only if it has a different signature SymbolType::Functions(signatures) => signatures.insert(signature), } } @@ -118,21 +127,21 @@ impl fmt::Display for ErrorInner { let location = self .pos .map(|p| format!("{}", p.0)) - .unwrap_or("?".to_string()); + .unwrap_or_else(|| "?".to_string()); write!(f, "{}\n\t{}", location, self.message) } } /// A function query in the current module. #[derive(Debug)] -struct FunctionQuery<'ast> { +struct FunctionQuery<'ast, T> { id: Identifier<'ast>, - inputs: Vec, + inputs: Vec>, /// Output types are optional as we try to infer them - outputs: Vec>, + outputs: Vec>>, } -impl<'ast> fmt::Display for FunctionQuery<'ast> { +impl<'ast, T: fmt::Display> fmt::Display for FunctionQuery<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { @@ -150,7 +159,7 @@ impl<'ast> fmt::Display for FunctionQuery<'ast> { " -> {}", match &self.outputs[0] { Some(t) => format!("{}", t), - None => format!("_"), + None => "_".into(), } ), _ => { @@ -170,68 +179,87 @@ impl<'ast> fmt::Display for FunctionQuery<'ast> { } } -impl<'ast> FunctionQuery<'ast> { +impl<'ast, T: Field> FunctionQuery<'ast, T> { /// Create a new query. fn new( id: Identifier<'ast>, - inputs: &Vec, - outputs: &Vec>, - ) -> FunctionQuery<'ast> { + inputs: &[Type<'ast, T>], + outputs: &[Option>], + ) -> Self { FunctionQuery { id, - inputs: inputs.clone(), - outputs: outputs.clone(), + inputs: inputs.to_owned(), + outputs: outputs.to_owned(), } } /// match a `FunctionKey` against this `FunctionQuery` - fn match_func(&self, func: &FunctionKey) -> bool { + fn match_func(&self, func: &DeclarationFunctionKey<'ast>) -> bool { self.id == func.id - && self.inputs == func.signature.inputs + && self + .inputs + .iter() + .zip(func.signature.inputs.iter()) + .all(|(input_ty, sig_ty)| input_ty.can_be_specialized_to(&sig_ty)) && self.outputs.len() == func.signature.outputs.len() - && self.outputs.iter().enumerate().all(|(index, t)| match t { - Some(ref t) => t == &func.signature.outputs[index], - _ => true, - }) + && self + .outputs + .iter() + .zip(func.signature.outputs.iter()) + .all(|(output_ty, sig_ty)| { + output_ty + .as_ref() + .map(|output_ty| output_ty.can_be_specialized_to(&sig_ty)) + .unwrap_or(true) + }) } - fn match_funcs(&self, funcs: &HashSet>) -> Option> { - funcs.iter().find(|func| self.match_func(func)).cloned() + fn match_funcs( + &self, + funcs: &HashSet>, + ) -> Vec> { + funcs + .iter() + .filter(|func| self.match_func(func)) + .cloned() + .collect() } } /// A scoped variable, so that we can delete all variables of a given scope when exiting it #[derive(Clone, Debug)] -pub struct ScopedVariable<'ast> { - id: Variable<'ast>, +pub struct ScopedVariable<'ast, T> { + id: Variable<'ast, T>, level: usize, } /// Identifiers of different `ScopedVariable`s should not conflict, so we define them as equivalent -impl<'ast> PartialEq for ScopedVariable<'ast> { - fn eq(&self, other: &ScopedVariable) -> bool { +impl<'ast, T> PartialEq for ScopedVariable<'ast, T> { + fn eq(&self, other: &Self) -> bool { self.id.id == other.id.id } } -impl<'ast> Hash for ScopedVariable<'ast> { +impl<'ast, T> Hash for ScopedVariable<'ast, T> { fn hash(&self, state: &mut H) { self.id.id.hash(state); } } -impl<'ast> Eq for ScopedVariable<'ast> {} +impl<'ast, T> Eq for ScopedVariable<'ast, T> {} /// Checker checks the semantics of a program, keeping track of functions and variables in scope -pub struct Checker<'ast> { - scope: HashSet>, - functions: HashSet>, +pub struct Checker<'ast, T> { + return_types: Option>>, + scope: HashSet>, + functions: HashSet>, level: usize, } -impl<'ast> Checker<'ast> { - fn new() -> Checker<'ast> { +impl<'ast, T: Field> Checker<'ast, T> { + fn new() -> Self { Checker { + return_types: None, scope: HashSet::new(), functions: HashSet::new(), level: 0, @@ -243,11 +271,11 @@ impl<'ast> Checker<'ast> { /// # Arguments /// /// * `prog` - The `Program` to be checked - pub fn check(prog: Program<'ast>) -> Result, Vec> { + pub fn check(prog: Program<'ast>) -> Result, Vec> { Checker::new().check_program(prog) } - fn check_program( + fn check_program( &mut self, program: Program<'ast>, ) -> Result, Vec> { @@ -261,7 +289,7 @@ impl<'ast> Checker<'ast> { Err(e) => errors.extend(e), }; - if errors.len() > 0 { + if !errors.is_empty() { return Err(errors); } @@ -287,8 +315,8 @@ impl<'ast> Checker<'ast> { id: String, s: StructDefinitionNode<'ast>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { let pos = s.pos(); let s = s.value; @@ -299,7 +327,7 @@ impl<'ast> Checker<'ast> { for field in s.fields { let member_id = field.value.id.to_string(); match self - .check_type(field.value.ty, module_id, &types) + .check_declaration_type(field.value.ty, module_id, &types, &HashSet::new()) .map(|t| (member_id, t)) { Ok(f) => match fields_set.insert(f.0.clone()) { @@ -315,27 +343,27 @@ impl<'ast> Checker<'ast> { } } - if errors.len() > 0 { + if !errors.is_empty() { return Err(errors); } - Ok(Type::Struct(StructType::new( + Ok(DeclarationType::Struct(DeclarationStructType::new( module_id.into(), id, fields .iter() - .map(|f| StructMember::new(f.0.clone(), f.1.clone())) + .map(|f| DeclarationStructMember::new(f.0.clone(), f.1.clone())) .collect(), ))) } - fn check_symbol_declaration( + fn check_symbol_declaration( &mut self, declaration: SymbolDeclarationNode<'ast>, module_id: &ModuleId, state: &mut State<'ast, T>, - functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, - symbol_unifier: &mut SymbolUnifier, + functions: &mut HashMap, TypedFunctionSymbol<'ast, T>>, + symbol_unifier: &mut SymbolUnifier<'ast>, ) -> Result<(), Vec> { let mut errors: Vec = vec![]; @@ -362,13 +390,16 @@ impl<'ast> Checker<'ast> { } .in_file(module_id), ), - true => {} + true => { + // there should be no entry in the map for this type yet + assert!(state + .types + .entry(module_id.clone()) + .or_default() + .insert(declaration.id.to_string(), ty) + .is_none()); + } }; - state - .types - .entry(module_id.clone()) - .or_default() - .insert(declaration.id.to_string(), ty); } Err(e) => errors.extend(e.into_iter().map(|inner| Error { inner, @@ -393,11 +424,11 @@ impl<'ast> Checker<'ast> { }; self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) + DeclarationFunctionKey::with_location(module_id.clone(), declaration.id) .signature(funct.signature.clone()), ); functions.insert( - FunctionKey::with_id(declaration.id.clone()) + DeclarationFunctionKey::with_location(module_id.clone(), declaration.id) .signature(funct.signature.clone()), TypedFunctionSymbol::Here(funct), ); @@ -420,8 +451,9 @@ impl<'ast> Checker<'ast> { .functions .iter() .filter(|(k, _)| k.id == import.symbol_id) - .map(|(_, v)| FunctionKey { - id: import.symbol_id.clone(), + .map(|(_, v)| DeclarationFunctionKey { + module: import.module_id.clone(), + id: import.symbol_id, signature: v.signature(&state.typed_modules).clone(), }) .collect(); @@ -439,7 +471,7 @@ impl<'ast> Checker<'ast> { // rename the type to the declared symbol let t = match t { - Type::Struct(t) => Type::Struct(StructType { + DeclarationType::Struct(t) => DeclarationType::Struct(DeclarationStructType { location: Some(StructLocation { name: declaration.id.into(), module: module_id.clone() @@ -468,7 +500,7 @@ impl<'ast> Checker<'ast> { .types .entry(module_id.clone()) .or_default() - .insert(declaration.id.to_string(), t.clone()); + .insert(declaration.id.to_string(), t); } (0, None) => { errors.push(ErrorInner { @@ -496,12 +528,12 @@ impl<'ast> Checker<'ast> { true => {} }; - self.functions.insert(candidate.clone().id(declaration.id)); + let local_key = candidate.clone().id(declaration.id).module(module_id.clone()); + + self.functions.insert(local_key.clone()); functions.insert( - candidate.clone().id(declaration.id), - TypedFunctionSymbol::There( - candidate, - import.module_id.clone(), + local_key, + TypedFunctionSymbol::There(candidate, ), ); } @@ -514,7 +546,9 @@ impl<'ast> Checker<'ast> { }; } Symbol::Flat(funct) => { - match symbol_unifier.insert_function(declaration.id, funct.signature()) { + match symbol_unifier + .insert_function(declaration.id, funct.signature().try_into().unwrap()) + { false => { errors.push( ErrorInner { @@ -531,31 +565,30 @@ impl<'ast> Checker<'ast> { }; self.functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone()), + DeclarationFunctionKey::with_location(module_id.clone(), declaration.id) + .signature(funct.signature().try_into().unwrap()), ); functions.insert( - FunctionKey::with_id(declaration.id.clone()) - .signature(funct.signature().clone()), + DeclarationFunctionKey::with_location(module_id.clone(), declaration.id) + .signature(funct.signature().try_into().unwrap()), TypedFunctionSymbol::Flat(funct), ); } }; // return if any errors occured - if errors.len() > 0 { + if !errors.is_empty() { return Err(errors); } Ok(()) } - fn check_module( + fn check_module( &mut self, module_id: &ModuleId, state: &mut State<'ast, T>, ) -> Result<(), Vec> { - let mut errors = vec![]; let mut checked_functions = HashMap::new(); // check if the module was already removed from the untyped ones @@ -574,18 +607,13 @@ impl<'ast> Checker<'ast> { // we go through symbol declarations and check them for declaration in module.symbols { - match self.check_symbol_declaration( + self.check_symbol_declaration( declaration, module_id, state, &mut checked_functions, &mut symbol_unifier, - ) { - Ok(()) => {} - Err(e) => { - errors.extend(e); - } - } + )? } Some(TypedModule { @@ -594,27 +622,19 @@ impl<'ast> Checker<'ast> { } }; - // return if any errors occured - if errors.len() > 0 { - return Err(errors); - } - // insert into typed_modules if we checked anything - match to_insert { - Some(typed_module) => { - // there should be no checked module at that key just yet, if there is we have a colision or we checked something twice - assert!(state - .typed_modules - .insert(module_id.clone(), typed_module) - .is_none()); - } - None => {} + if let Some(typed_module) = to_insert { + // there should be no checked module at that key just yet, if there is we have a colision or we checked something twice + assert!(state + .typed_modules + .insert(module_id.clone(), typed_module) + .is_none()); }; Ok(()) } - fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { + fn check_single_main(module: &TypedModule) -> Result<(), ErrorInner> { match module .functions .iter() @@ -624,7 +644,7 @@ impl<'ast> Checker<'ast> { 1 => Ok(()), 0 => Err(ErrorInner { pos: None, - message: format!("No main function found"), + message: "No main function found".into(), }), n => Err(ErrorInner { pos: None, @@ -633,9 +653,9 @@ impl<'ast> Checker<'ast> { } } - fn check_for_var(&self, var: &VariableNode) -> Result<(), ErrorInner> { + fn check_for_var(&self, var: &VariableNode<'ast>) -> Result<(), ErrorInner> { match var.value.get_type() { - UnresolvedType::FieldElement => Ok(()), + UnresolvedType::Uint(32) => Ok(()), t => Err(ErrorInner { pos: Some(var.pos()), message: format!("Variable in for loop cannot have type {}", t), @@ -643,37 +663,69 @@ impl<'ast> Checker<'ast> { } } - fn check_function( + fn check_function( &mut self, funct_node: FunctionNode<'ast>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, Vec> { + assert!(self.scope.is_empty()); + assert!(self.return_types.is_none()); + self.enter_scope(); let pos = funct_node.pos(); let mut errors = vec![]; let funct = funct_node.value; - let mut arguments_checked = vec![]; let mut signature = None; assert_eq!(funct.arguments.len(), funct.signature.inputs.len()); - for arg in funct.arguments { - match self.check_parameter(arg, module_id, types) { - Ok(a) => { - self.insert_into_scope(a.id.clone()); - arguments_checked.push(a); - } - Err(e) => errors.extend(e), - } - } + let mut arguments_checked = vec![]; let mut statements_checked = vec![]; match self.check_signature(funct.signature, module_id, types) { Ok(s) => { + // define variables for the constants + for generic in &s.generics { + let generic = generic.clone().unwrap(); // for declaration signatures, generics cannot be ignored + + let v = Variable::with_id_and_type( + match generic { + Constant::Generic(g) => g, + _ => unreachable!(), + }, + Type::Uint(UBitwidth::B32), + ); + // we don't have to check for conflicts here, because this was done when checking the signature + self.insert_into_scope(v.clone()); + } + + for (arg, decl_ty) in funct.arguments.into_iter().zip(s.inputs.iter()) { + let pos = arg.pos(); + + let arg = arg.value; + + let decl_v = + DeclarationVariable::with_id_and_type(arg.id.value.id, decl_ty.clone()); + + match self.insert_into_scope(decl_v.clone()) { + true => {} + false => { + errors.push(ErrorInner { + pos: Some(pos), + message: format!("Duplicate name in function definition: `{}` was previously declared as an argument or a generic constant", arg.id.value.id) + }); + } + }; + arguments_checked.push(DeclarationParameter { + id: decl_v, + private: arg.private, + }); + } + let mut found_return = false; for stat in funct.statements.into_iter() { @@ -683,7 +735,7 @@ impl<'ast> Checker<'ast> { if found_return { errors.push(ErrorInner { pos, - message: format!("Expected a single return statement",), + message: "Expected a single return statement".to_string(), }); } @@ -740,11 +792,14 @@ impl<'ast> Checker<'ast> { } }; - if errors.len() > 0 { + self.exit_scope(); + + if !errors.is_empty() { return Err(errors); } - self.exit_scope(); + self.return_types = None; + assert!(self.scope.is_empty()); Ok(TypedFunction { arguments: arguments_checked, @@ -753,32 +808,35 @@ impl<'ast> Checker<'ast> { }) } - fn check_parameter( - &self, - p: ParameterNode<'ast>, - module_id: &ModuleId, - types: &TypeMap, - ) -> Result, Vec> { - let var = self.check_variable(p.value.id, module_id, types)?; - - Ok(Parameter { - id: var, - private: p.value.private, - }) - } - fn check_signature( - &self, - signature: UnresolvedSignature, + &mut self, + signature: UnresolvedSignature<'ast>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { let mut errors = vec![]; let mut inputs = vec![]; let mut outputs = vec![]; + let mut generics = vec![]; + + let mut constants = HashSet::new(); + + for g in signature.generics { + match constants.insert(g.value) { + true => { + generics.push(Some(Constant::Generic(g.value))); + } + false => { + errors.push(ErrorInner { + pos: Some(g.pos()), + message: format!("Generic parameter {} is already declared", g.value), + }); + } + } + } for t in signature.inputs { - match self.check_type(t, module_id, types) { + match self.check_declaration_type(t, module_id, types, &constants) { Ok(t) => { inputs.push(t); } @@ -789,7 +847,7 @@ impl<'ast> Checker<'ast> { } for t in signature.outputs { - match self.check_type(t, module_id, types) { + match self.check_declaration_type(t, module_id, types, &constants) { Ok(t) => { outputs.push(t); } @@ -799,19 +857,25 @@ impl<'ast> Checker<'ast> { } } - if errors.len() > 0 { + if !errors.is_empty() { return Err(errors); } - Ok(Signature { inputs, outputs }) + self.return_types = Some(outputs.clone()); + + Ok(DeclarationSignature { + generics, + inputs, + outputs, + }) } fn check_type( - &self, - ty: UnresolvedTypeNode, + &mut self, + ty: UnresolvedTypeNode<'ast>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result { + types: &TypeMap<'ast>, + ) -> Result, ErrorInner> { let pos = ty.pos(); let ty = ty.value; @@ -819,10 +883,122 @@ impl<'ast> Checker<'ast> { UnresolvedType::FieldElement => Ok(Type::FieldElement), UnresolvedType::Boolean => Ok(Type::Boolean), UnresolvedType::Uint(bitwidth) => Ok(Type::uint(bitwidth)), - UnresolvedType::Array(t, size) => Ok(Type::Array(ArrayType::new( - self.check_type(*t, module_id, types)?, - size, - ))), + UnresolvedType::Array(t, size) => { + let size = self.check_expression(size, module_id, types)?; + + let ty = size.get_type(); + + let size = match size { + TypedExpression::Uint(e) => match e.bitwidth() { + UBitwidth::B32 => Ok(e), + _ => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant, found {} of type {}", + e, ty + ), + }), + }, + TypedExpression::Int(v) => UExpression::try_from_int(v.clone(), UBitwidth::B32) + .map_err(|_| ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant, found {} of type {}", + v, ty + ), + }), + _ => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant, found {} of type {}", + size, ty + ), + }), + }?; + + Ok(Type::Array(ArrayType::new( + self.check_type(*t, module_id, types)?, + size, + ))) + } + UnresolvedType::User(id) => types + .get(module_id) + .unwrap() + .get(&id) + .cloned() + .ok_or_else(|| ErrorInner { + pos: Some(pos), + message: format!("Undefined type {}", id), + }) + .map(|t| t.into()), + } + } + + fn check_generic_expression( + &mut self, + expr: ExpressionNode<'ast>, + ) -> Result, ErrorInner> { + let pos = expr.pos(); + + match expr.value { + Expression::U32Constant(c) => Ok(Constant::Concrete(c)), + Expression::IntConstant(c) => { + if c <= BigUint::from(2u128.pow(32) - 1) { + Ok(Constant::Concrete( + u32::from_str_radix(&c.to_str_radix(16), 16).unwrap(), + )) + } else { + Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant or an identifier, found {}", + Expression::IntConstant(c) + ), + }) + } + } + Expression::Identifier(name) => Ok(Constant::Generic(name)), + e => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected array dimension to be a u32 constant or an identifier, found {}", + e + ), + }), + } + } + + fn check_declaration_type( + &mut self, + ty: UnresolvedTypeNode<'ast>, + module_id: &ModuleId, + types: &TypeMap<'ast>, + constants: &HashSet>, + ) -> Result, ErrorInner> { + let pos = ty.pos(); + let ty = ty.value; + + match ty { + UnresolvedType::FieldElement => Ok(DeclarationType::FieldElement), + UnresolvedType::Boolean => Ok(DeclarationType::Boolean), + UnresolvedType::Uint(bitwidth) => Ok(DeclarationType::uint(bitwidth)), + UnresolvedType::Array(t, size) => { + let checked_size = self.check_generic_expression(size.clone())?; + + if let Constant::Generic(g) = checked_size { + if !constants.contains(g) { + return Err(ErrorInner { + pos: Some(pos), + message: format!("Undeclared generic parameter in function definition: `{}` isn\'t declared as a generic constant", g) + }); + } + }; + + Ok(DeclarationType::Array(DeclarationArrayType::new( + self.check_declaration_type(*t, module_id, types, constants)?, + checked_size, + ))) + } UnresolvedType::User(id) => { types .get(module_id) @@ -838,11 +1014,11 @@ impl<'ast> Checker<'ast> { } fn check_variable( - &self, + &mut self, v: crate::absy::VariableNode<'ast>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result, Vec> { + types: &TypeMap<'ast>, + ) -> Result, Vec> { Ok(Variable::with_id_and_type( v.value.id, self.check_type(v.value._type, module_id, types) @@ -850,26 +1026,182 @@ impl<'ast> Checker<'ast> { )) } - fn check_statement( + fn check_for_loop( + &mut self, + var: crate::absy::VariableNode<'ast>, + range: (ExpressionNode<'ast>, ExpressionNode<'ast>), + statements: Vec>, + pos: (Position, Position), + module_id: &ModuleId, + types: &TypeMap<'ast>, + ) -> Result, Vec> { + self.check_for_var(&var).map_err(|e| vec![e])?; + + let var = self.check_variable(var, module_id, types).unwrap(); + + let from = self + .check_expression(range.0, module_id, &types) + .map_err(|e| vec![e])?; + let to = self + .check_expression(range.1, module_id, &types) + .map_err(|e| vec![e])?; + + let from = match from { + TypedExpression::Uint(from) => match from.bitwidth() { + UBitwidth::B32 => Ok(from), + bitwidth => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected lower loop bound to be of type u32, found {}", + Type::::Uint(bitwidth) + ), + }), + }, + TypedExpression::Int(v) => { + UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner { + pos: Some(pos), + message: format!( + "Expected lower loop bound to be of type u32, found {}", + Type::::Int + ), + }) + } + from => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected lower loop bound to be of type u32, found {}", + from.get_type() + ), + }), + } + .map_err(|e| vec![e])?; + + let to = match to { + TypedExpression::Uint(to) => match to.bitwidth() { + UBitwidth::B32 => Ok(to), + bitwidth => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected upper loop bound to be of type u32, found {}", + Type::::Uint(bitwidth) + ), + }), + }, + TypedExpression::Int(v) => { + UExpression::try_from_int(v, UBitwidth::B32).map_err(|_| ErrorInner { + pos: Some(pos), + message: format!( + "Expected upper loop bound to be of type u32, found {}", + Type::::Int + ), + }) + } + to => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Expected upper loop bound to be of type u32, found {}", + to.get_type() + ), + }), + } + .map_err(|e| vec![e])?; + + self.insert_into_scope(var.clone()); + + let mut checked_statements = vec![]; + + for stat in statements { + let checked_stat = self.check_statement(stat, module_id, types)?; + checked_statements.push(checked_stat); + } + + Ok(TypedStatement::For(var, from, to, checked_statements)) + } + + fn check_statement( &mut self, stat: StatementNode<'ast>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, Vec> { let pos = stat.pos(); match stat.value { - Statement::Return(list) => { + Statement::Return(e) => { let mut expression_list_checked = vec![]; + let mut errors = vec![]; - for e in list.value.expressions { + // we clone the return types because there might be other return statements + let return_types = self.return_types.clone().unwrap(); + + for e in e.value.expressions.into_iter() { let e_checked = self .check_expression(e, module_id, &types) .map_err(|e| vec![e])?; expression_list_checked.push(e_checked); } - Ok(TypedStatement::Return(expression_list_checked)) + let res = match expression_list_checked.len() == return_types.len() { + true => match expression_list_checked + .iter() + .zip(return_types.clone()) + .map(|(e, t)| TypedExpression::align_to_type(e.clone(), t.into())) + .collect::, _>>() + .map_err(|e| { + vec![ErrorInner { + pos: Some(pos), + message: format!( + "Expected return value to be of type {}, found {}", + e.1, e.0 + ), + }] + }) { + Ok(e) => { + match e.iter().map(|e| e.get_type()).collect::>() == return_types + { + true => {} + false => errors.push(ErrorInner { + pos: Some(pos), + message: format!( + "Expected ({}) in return statement, found ({})", + return_types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", "), + e.iter() + .map(|e| e.get_type()) + .map(|t| t.to_string()) + .collect::>() + .join(", ") + ), + }), + }; + TypedStatement::Return(e) + } + Err(err) => { + errors.extend(err); + TypedStatement::Return(expression_list_checked) + } + }, + false => { + errors.push(ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} expressions in return statement, found {}", + return_types.len(), + expression_list_checked.len() + ), + }); + TypedStatement::Return(expression_list_checked) + } + }; + + if !errors.is_empty() { + return Err(errors); + } + + Ok(res) } Statement::Declaration(var) => { let var = self.check_variable(var, module_id, types)?; @@ -885,16 +1217,14 @@ impl<'ast> Checker<'ast> { Statement::Definition(assignee, expr) => { // we create multidef when rhs is a function call to benefit from inference // check rhs is not a function call here - match expr.value { - Expression::FunctionCall(..) => panic!("Parser should not generate Definition where the right hand side is a FunctionCall"), - _ => {} - } + if let Expression::FunctionCall(..) = expr.value { + panic!("Parser should not generate Definition where the right hand side is a FunctionCall") + } // check the expression to be assigned let checked_expr = self .check_expression(expr, module_id, &types) .map_err(|e| vec![e])?; - let expression_type = checked_expr.get_type(); // check that the assignee is declared and is well formed let var = self @@ -904,16 +1234,35 @@ impl<'ast> Checker<'ast> { let var_type = var.get_type(); // make sure the assignee has the same type as the rhs - match var_type == expression_type { - true => Ok(TypedStatement::Definition(var, checked_expr)), - false => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expression {} of type {} cannot be assigned to {} of type {}", - checked_expr, expression_type, var, var_type - ), - }), + match var_type.clone() { + Type::FieldElement => FieldElementExpression::try_from_typed(checked_expr) + .map(TypedExpression::from), + Type::Boolean => { + BooleanExpression::try_from_typed(checked_expr).map(TypedExpression::from) + } + Type::Uint(bitwidth) => UExpression::try_from_typed(checked_expr, bitwidth) + .map(TypedExpression::from), + Type::Array(array_ty) => { + ArrayExpression::try_from_typed(checked_expr, *array_ty.ty) + .map(TypedExpression::from) + } + Type::Struct(struct_ty) => { + StructExpression::try_from_typed(checked_expr, struct_ty) + .map(TypedExpression::from) + } + Type::Int => Err(checked_expr), // Integers cannot be assigned } + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expression {} of type {} cannot be assigned to {} of type {}", + e, + e.get_type(), + var.clone(), + var_type + ), + }) + .map(|rhs| TypedStatement::Definition(var, rhs)) .map_err(|e| vec![e]) } Statement::Assertion(e) => { @@ -937,68 +1286,50 @@ impl<'ast> Checker<'ast> { Statement::For(var, from, to, statements) => { self.enter_scope(); - self.check_for_var(&var).map_err(|e| vec![e])?; - - let var = self.check_variable(var, module_id, types).unwrap(); - - let from = self - .check_expression(from, module_id, &types) - .map_err(|e| vec![e])?; - let to = self - .check_expression(to, module_id, &types) - .map_err(|e| vec![e])?; - - let from = match from { - TypedExpression::FieldElement(e) => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected lower loop bound to be of type field, found {}", - e.get_type() - ), - }), - } - .map_err(|e| vec![e])?; - - let to = match to { - TypedExpression::FieldElement(e) => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected higher loop bound to be of type field, found {}", - e.get_type() - ), - }), - } - .map_err(|e| vec![e])?; - - self.insert_into_scope(var.clone()); - - let mut checked_statements = vec![]; - - for stat in statements { - let checked_stat = self.check_statement(stat, module_id, types)?; - checked_statements.push(checked_stat); - } + let res = self.check_for_loop(var, (from, to), statements, pos, module_id, types); self.exit_scope(); - Ok(TypedStatement::For(var, from, to, checked_statements)) + + res } Statement::MultipleDefinition(assignees, rhs) => { match rhs.value { // Right side has to be a function call - Expression::FunctionCall(fun_id, arguments) => { + Expression::FunctionCall(fun_id, generics, arguments) => { + // check the generic arguments, if any + let generics_checked: Option>>> = generics + .map(|generics| + generics.into_iter().map(|g| + g.map(|g| { + let pos = g.pos(); + self.check_expression(g, module_id, &types).and_then(|g| { + UExpression::try_from_typed(g, UBitwidth::B32).map_err( + |e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} to be of type u32, found {}", + e, + e.get_type(), + ), + }, + ) + }) + }) + .transpose() + ) + .collect::>() + ).transpose().map_err(|e| vec![e])?; // check lhs assignees are defined - let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee::(a, module_id, types)).partition(|r| r.is_ok()); + let (assignees, errors): (Vec<_>, Vec<_>) = assignees.into_iter().map(|a| self.check_assignee(a, module_id, types)).partition(|r| r.is_ok()); - if errors.len() > 0 { + if !errors.is_empty() { return Err(errors.into_iter().map(|e| e.unwrap_err()).collect()); } let assignees: Vec<_> = assignees.into_iter().map(|a| a.unwrap()).collect(); - let assignee_types = assignees.iter().map(|a| Some(a.get_type().clone())).collect(); + let assignee_types: Vec<_> = assignees.iter().map(|a| Some(a.get_type().clone())).collect(); // find argument types let mut arguments_checked = vec![]; @@ -1007,23 +1338,35 @@ impl<'ast> Checker<'ast> { arguments_checked.push(arg_checked); } - let arguments_types = + let arguments_types: Vec<_> = arguments_checked.iter().map(|a| a.get_type()).collect(); let query = FunctionQuery::new(&fun_id, &arguments_types, &assignee_types); - let f = self.find_function(&query); + let functions = self.find_functions(&query); - match f { + match functions.len() { // the function has to be defined - Some(f) => { + 1 => { - let call = TypedExpressionList::FunctionCall(f.clone(), arguments_checked, f.signature.outputs.clone()); + let mut functions = functions; + let f = functions.pop().unwrap(); + + let arguments_checked = arguments_checked.into_iter().zip(f.signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::, _>>().map_err(|e| vec![ErrorInner { + pos: Some(pos), + message: format!("Expected function call argument to be of type {}, found {} of type {}", e.1, e.0, e.0.get_type()) + }])?; + + let call = TypedExpressionList::FunctionCall(f.clone(), generics_checked.unwrap_or_else(|| vec![None; f.signature.generics.len()]), arguments_checked, assignees.iter().map(|a| a.get_type()).collect()); Ok(TypedStatement::MultipleDefinition(assignees, call)) }, - None => Err(ErrorInner { pos: Some(pos), + 0 => Err(ErrorInner { pos: Some(pos), message: format!("Function definition for function {} with signature {} not found.", fun_id, query) }), + n => Err(ErrorInner { + pos: Some(pos), + message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) + }) } } _ => Err(ErrorInner { @@ -1035,11 +1378,11 @@ impl<'ast> Checker<'ast> { } } - fn check_assignee( + fn check_assignee( &mut self, assignee: AssigneeNode<'ast>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = assignee.pos(); // check that the assignee is declared @@ -1070,18 +1413,17 @@ impl<'ast> Checker<'ast> { ), }; - let checked_typed_index = match checked_index { - TypedExpression::FieldElement(e) => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected array {} index to have type field, found {}", - checked_assignee, - e.get_type() - ), - }), - }?; + let checked_typed_index = + UExpression::try_from_typed(checked_index, UBitwidth::B32).map_err( + |e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected array {} index to have type u32, found {}", + checked_assignee, + e.get_type() + ), + }, + )?; Ok(TypedAssignee::Select( box checked_assignee, @@ -1090,7 +1432,6 @@ impl<'ast> Checker<'ast> { } ty => Err(ErrorInner { pos: Some(pos), - message: format!( "Cannot access element at index {} on {} of type {}", index, checked_assignee, ty, @@ -1107,7 +1448,16 @@ impl<'ast> Checker<'ast> { Some(_) => Ok(TypedAssignee::Member(box checked_assignee, member.into())), None => Err(ErrorInner { pos: Some(pos), - message: format!("{} doesn't have member {}", ty, member), + message: format!( + "{} {{{}}} doesn't have member {}", + ty, + members + .iter() + .map(|m| format!("{}: {}", m.id, m.ty)) + .collect::>() + .join(", "), + member + ), }), }, ty => Err(ErrorInner { @@ -1123,12 +1473,12 @@ impl<'ast> Checker<'ast> { } } - fn check_spread_or_expression( + fn check_spread_or_expression( &mut self, spread_or_expression: SpreadOrExpression<'ast>, module_id: &ModuleId, - types: &TypeMap, - ) -> Result>, ErrorInner> { + types: &TypeMap<'ast>, + ) -> Result, ErrorInner> { match spread_or_expression { SpreadOrExpression::Spread(s) => { let pos = s.pos(); @@ -1136,80 +1486,33 @@ impl<'ast> Checker<'ast> { let checked_expression = self.check_expression(s.value.expression, module_id, &types)?; - let res = match checked_expression { - TypedExpression::Array(e) => { - let ty = e.inner_type().clone(); - - let size = e.size(); - match e.into_inner() { - // if we're doing a spread over an inline array, we return the inside of the array: ...[x, y, z] == x, y, z - // this is not strictly needed, but it makes spreads memory linear rather than quadratic - ArrayExpressionInner::Value(v) => Ok(v), - // otherwise we return a[0], ..., a[a.size() -1 ] - e => Ok((0..size) - .map(|i| match &ty { - Type::FieldElement => FieldElementExpression::select( - e.clone().annotate(Type::FieldElement, size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Uint(bitwidth) => UExpression::select( - e.clone().annotate(Type::Uint(*bitwidth), size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Boolean => BooleanExpression::select( - e.clone().annotate(Type::Boolean, size), - FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Array(array_type) => ArrayExpressionInner::Select( - box e - .clone() - .annotate(Type::Array(array_type.clone()), size), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(*array_type.ty.clone(), array_type.size) - .into(), - Type::Struct(members) => StructExpressionInner::Select( - box e.clone().annotate(Type::Struct(members.clone()), size), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(members.clone()) - .into(), - }) - .collect()), - } - } + match checked_expression { + TypedExpression::Array(a) => Ok(TypedExpressionOrSpread::Spread(a.into())), e => Err(ErrorInner { pos: Some(pos), - message: format!( "Expected spread operator to apply on array, found {}", e.get_type() ), }), - }; - - let res = res.unwrap(); - - Ok(res) - } - SpreadOrExpression::Expression(e) => { - self.check_expression(e, module_id, &types).map(|r| vec![r]) + } } + SpreadOrExpression::Expression(e) => self + .check_expression(e, module_id, &types) + .map(|r| r.into()), } } - fn check_expression( + fn check_expression( &mut self, expr: ExpressionNode<'ast>, module_id: &ModuleId, - types: &TypeMap, + types: &TypeMap<'ast>, ) -> Result, ErrorInner> { let pos = expr.pos(); match expr.value { + Expression::IntConstant(v) => Ok(IntExpression::Value(v).into()), Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()), Expression::Identifier(name) => { // check that `id` is defined in the scope @@ -1230,6 +1533,7 @@ impl<'ast> Checker<'ast> { Type::Struct(members) => Ok(StructExpressionInner::Identifier(name.into()) .annotate(members) .into()), + Type::Int => unreachable!(), }, None => Err(ErrorInner { pos: Some(pos), @@ -1241,14 +1545,25 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + use self::TypedExpression::*; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `+` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { + (Int(e1), Int(e2)) => Ok(IntExpression::Add(box e1, box e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(FieldElementExpression::Add(box e1, box e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(UExpression::add(e1, e2).into()) + Ok((e1 + e2).into()) } (t1, t2) => Err(ErrorInner { pos: Some(pos), @@ -1265,15 +1580,22 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + use self::TypedExpression::*; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `-` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { - (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { + (Int(e1), Int(e2)) => Ok(IntExpression::Sub(box e1, box e2).into()), + (FieldElement(e1), FieldElement(e2)) => { Ok(FieldElementExpression::Sub(box e1, box e2).into()) } - (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) - if e1.get_type() == e2.get_type() => - { - Ok(UExpression::sub(e1, e2).into()) - } + (Uint(e1), Uint(e2)) if e1.get_type() == e2.get_type() => Ok((e1 - e2).into()), (t1, t2) => Err(ErrorInner { pos: Some(pos), @@ -1289,14 +1611,25 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + use self::TypedExpression::*; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `*` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { + (Int(e1), Int(e2)) => Ok(IntExpression::Mult(box e1, box e2).into()), (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(FieldElementExpression::Mult(box e1, box e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(UExpression::mult(e1, e2).into()) + Ok((e1 * e2).into()) } (t1, t2) => Err(ErrorInner { pos: Some(pos), @@ -1313,14 +1646,25 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + use self::TypedExpression::*; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `/` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { - (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { + (Int(e1), Int(e2)) => Ok(IntExpression::Div(box e1, box e2).into()), + (FieldElement(e1), FieldElement(e2)) => { Ok(FieldElementExpression::Div(box e1, box e2).into()) } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(UExpression::div(e1, e2).into()) + Ok((e1 / e2).into()) } (t1, t2) => Err(ErrorInner { pos: Some(pos), @@ -1337,11 +1681,19 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `%` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.get_type() == e2.get_type() => { - Ok(UExpression::rem(e1, e2).into()) + Ok((e1 % e2).into()) } (t1, t2) => Err(ErrorInner { pos: Some(pos), @@ -1358,15 +1710,24 @@ impl<'ast> Checker<'ast> { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + let e1_checked = match FieldElementExpression::try_from_typed(e1_checked) { + Ok(e) => e.into(), + Err(e) => e, + }; + let e2_checked = match UExpression::try_from_typed(e2_checked, UBitwidth::B32) { + Ok(e) => e.into(), + Err(e) => e, + }; + match (e1_checked, e2_checked) { - (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => Ok( + (TypedExpression::FieldElement(e1), TypedExpression::Uint(e2)) => Ok( TypedExpression::FieldElement(FieldElementExpression::Pow(box e1, box e2)), ), (t1, t2) => Err(ErrorInner { pos: Some(pos), message: format!( - "Expected only field elements, found {}, {}", + "Expected `field` and `u32`, found {}, {}", t1.get_type(), t2.get_type() ), @@ -1414,36 +1775,44 @@ impl<'ast> Checker<'ast> { let consequence_checked = self.check_expression(consequence, module_id, &types)?; let alternative_checked = self.check_expression(alternative, module_id, &types)?; + let (consequence_checked, alternative_checked) = + TypedExpression::align_without_integers( + consequence_checked, + alternative_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", e1.get_type(), e2.get_type()), + })?; + match condition_checked { TypedExpression::Boolean(condition) => { - let consequence_type = consequence_checked.get_type(); - let alternative_type = alternative_checked.get_type(); - match consequence_type == alternative_type { - true => match (consequence_checked, alternative_checked) { - (TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => { - Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into()) - }, - (TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => { - Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into()) - }, - (TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => { - let inner_type = consequence.inner_type().clone(); - let size = consequence.size(); - Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into()) - }, - (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { - let ty = consequence.ty().clone(); - Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into()) - }, - (TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => { - let bitwidth = consequence.bitwidth(); - Ok(UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth).into()) - }, - _ => unreachable!("types should match here as we checked them explicitly") - } - false => Err(ErrorInner { + match (consequence_checked, alternative_checked) { + (TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => { + Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into()) + }, + (TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => { + Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into()) + }, + (TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => { + let inner_type = consequence.inner_type().clone(); + let size = consequence.size(); + Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into()) + }, + (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { + let ty = consequence.ty().clone(); + Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into()) + }, + (TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => { + let bitwidth = consequence.bitwidth(); + Ok(UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth).into()) + }, + (TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => { + Ok(IntExpression::IfElse(box condition, box consequence, box alternative).into()) + }, + (c, a) => Err(ErrorInner { pos: Some(pos), - message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", consequence_type, alternative_type) + message: format!("{{consequence}} and {{alternative}} in `if/else` expression should have the same type, found {}, {}", c.get_type(), a.get_type()) }) } } @@ -1470,7 +1839,34 @@ impl<'ast> Checker<'ast> { Expression::U8Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(8).into()), Expression::U16Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(16).into()), Expression::U32Constant(n) => Ok(UExpressionInner::Value(n.into()).annotate(32).into()), - Expression::FunctionCall(fun_id, arguments) => { + Expression::FunctionCall(fun_id, generics, arguments) => { + // check the generic arguments, if any + let generics_checked: Option>>> = generics + .map(|generics| { + generics + .into_iter() + .map(|g| { + g.map(|g| { + let pos = g.pos(); + self.check_expression(g, module_id, &types).and_then(|g| { + UExpression::try_from_typed(g, UBitwidth::B32).map_err( + |e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected {} to be of type u32, found {}", + e, + e.get_type(), + ), + }, + ) + }) + }) + .transpose() + }) + .collect::>() + }) + .transpose()?; + // check the arguments let mut arguments_checked = vec![]; for arg in arguments { @@ -1485,55 +1881,87 @@ impl<'ast> Checker<'ast> { // outside of multidef, function calls must have a single return value // we use type inference to determine the type of the return, so we don't specify it - let query = FunctionQuery::new(&fun_id, &arguments_types, &vec![None]); + let query = FunctionQuery::new(&fun_id, &arguments_types, &[None]); - let f = self.find_function(&query); + let functions = self.find_functions(&query); - match f { + match functions.len() { // the function has to be defined - Some(f) => { + 1 => { + let mut functions = functions; + + let f = functions.pop().unwrap(); + + let signature = f.signature; + + let arguments_checked = arguments_checked.into_iter().zip(signature.inputs.clone()).map(|(a, t)| TypedExpression::align_to_type(a, t.into())).collect::, _>>().map_err(|e| ErrorInner { + pos: Some(pos), + message: format!("Expected function call argument to be of type {}, found {}", e.1, e.0) + })?; + + let output_types = signature.get_output_types(arguments_checked.iter().map(|a| a.get_type()).collect()).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Failed to infer value for generic parameter `{}`, try being more explicit by using an intermediate variable", + e, + ), + })?; + + let generics_checked = generics_checked.unwrap_or_else(|| vec![None; signature.generics.len()]); + // the return count has to be 1 - match f.signature.outputs.len() { - 1 => match &f.signature.outputs[0] { + match output_types.len() { + 1 => match &output_types[0] { + Type::Int => unreachable!(), Type::FieldElement => Ok(FieldElementExpression::FunctionCall( - FunctionKey { - id: f.id.clone(), - signature: f.signature.clone(), + DeclarationFunctionKey { + module: module_id.clone(), + id: f.id, + signature: signature.clone(), }, + generics_checked, arguments_checked, ) .into()), Type::Boolean => Ok(BooleanExpression::FunctionCall( - FunctionKey { - id: f.id.clone(), - signature: f.signature.clone(), + DeclarationFunctionKey { + module: module_id.clone(), + id: f.id, + signature: signature.clone(), }, + generics_checked, arguments_checked, ) .into()), Type::Uint(bitwidth) => Ok(UExpressionInner::FunctionCall( - FunctionKey { - id: f.id.clone(), - signature: f.signature.clone(), + DeclarationFunctionKey { + module: module_id.clone(), + id: f.id, + signature: signature.clone(), }, + generics_checked, arguments_checked, ) .annotate(*bitwidth) .into()), Type::Struct(members) => Ok(StructExpressionInner::FunctionCall( - FunctionKey { - id: f.id.clone(), - signature: f.signature.clone(), + DeclarationFunctionKey { + module: module_id.clone(), + id: f.id, + signature: signature.clone(), }, + generics_checked, arguments_checked, ) .annotate(members.clone()) .into()), Type::Array(array_type) => Ok(ArrayExpressionInner::FunctionCall( - FunctionKey { - id: f.id.clone(), - signature: f.signature.clone(), + DeclarationFunctionKey { + module: module_id.clone(), + id: f.id, + signature: signature.clone(), }, + generics_checked, arguments_checked, ) .annotate(*array_type.ty.clone(), array_type.size.clone()) @@ -1549,7 +1977,7 @@ impl<'ast> Checker<'ast> { }), } } - None => Err(ErrorInner { + 0 => Err(ErrorInner { pos: Some(pos), message: format!( @@ -1557,14 +1985,49 @@ impl<'ast> Checker<'ast> { fun_id, query ), }), + n => Err(ErrorInner { + pos: Some(pos), + message: format!("Ambiguous call to function {}, {} candidates were found. Please be more explicit.", fun_id, n) + }), } } Expression::Lt(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::Lt(box e1, box e2).into()) + Ok(BooleanExpression::FieldLt(box e1, box e2).into()) + } + (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { + if e1.get_type() == e2.get_type() { + Ok(BooleanExpression::UintLt(box e1, box e2).into()) + } else { + Err(ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + }) + } } (e1, e2) => Err(ErrorInner { pos: Some(pos), @@ -1581,9 +2044,40 @@ impl<'ast> Checker<'ast> { Expression::Le(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::Le(box e1, box e2).into()) + Ok(BooleanExpression::FieldLe(box e1, box e2).into()) + } + (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { + if e1.get_type() == e2.get_type() { + Ok(BooleanExpression::UintLe(box e1, box e2).into()) + } else { + Err(ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + }) + } } (e1, e2) => Err(ErrorInner { pos: Some(pos), @@ -1600,6 +2094,21 @@ impl<'ast> Checker<'ast> { Expression::Eq(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { Ok(BooleanExpression::FieldEq(box e1, box e2).into()) @@ -1607,9 +2116,7 @@ impl<'ast> Checker<'ast> { (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::BoolEq(box e1, box e2).into()) } - (TypedExpression::Array(e1), TypedExpression::Array(e2)) - if e1.get_type() == e2.get_type() => - { + (TypedExpression::Array(e1), TypedExpression::Array(e2)) => { Ok(BooleanExpression::ArrayEq(box e1, box e2).into()) } (TypedExpression::Struct(e1), TypedExpression::Struct(e2)) @@ -1637,9 +2144,40 @@ impl<'ast> Checker<'ast> { Expression::Ge(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::Ge(box e1, box e2).into()) + Ok(BooleanExpression::FieldGe(box e1, box e2).into()) + } + (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { + if e1.get_type() == e2.get_type() { + Ok(BooleanExpression::UintGe(box e1, box e2).into()) + } else { + Err(ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + }) + } } (e1, e2) => Err(ErrorInner { pos: Some(pos), @@ -1656,9 +2194,40 @@ impl<'ast> Checker<'ast> { Expression::Gt(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { (TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => { - Ok(BooleanExpression::Gt(box e1, box e2).into()) + Ok(BooleanExpression::FieldGt(box e1, box e2).into()) + } + (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) => { + if e1.get_type() == e2.get_type() { + Ok(BooleanExpression::UintGt(box e1, box e2).into()) + } else { + Err(ErrorInner { + pos: Some(pos), + message: format!( + "Cannot compare {} of type {} to {} of type {}", + e1, + e1.get_type(), + e2, + e2.get_type() + ), + }) + } } (e1, e2) => Err(ErrorInner { pos: Some(pos), @@ -1676,145 +2245,96 @@ impl<'ast> Checker<'ast> { let array = self.check_expression(array, module_id, &types)?; match index { - RangeOrExpression::Range(r) => match array { - TypedExpression::Array(array) => { - let array_size = array.size(); - let inner_type = array.inner_type().clone(); + RangeOrExpression::Range(r) => { + match array { + TypedExpression::Array(array) => { + let array_size = array.size(); - // check that the bounds are valid expressions - let from = r - .value - .from - .map(|e| self.check_expression(e, module_id, &types)) - .unwrap_or(Ok(FieldElementExpression::Number(T::from(0)).into()))?; + let inner_type = array.inner_type().clone(); - let to = r - .value - .to - .map(|e| self.check_expression(e, module_id, &types)) - .unwrap_or(Ok(FieldElementExpression::Number(T::from( - array_size, - )) - .into()))?; + // check that the bounds are valid expressions + let from = r + .value + .from + .map(|e| self.check_expression(e, module_id, &types)) + .unwrap_or_else(|| Ok(UExpression::from(0u32).into()))?; - // check the bounds are field constants - // Note: it would be nice to allow any field expression, and check it's a constant after constant propagation, - // but it's tricky from a type perspective: the size of the slice changes the type of the resulting array, - // which doesn't work well with our static array approach. Enabling arrays to have unknown size introduces a lot - // of complexity in the compiler, as function selection in inlining requires knowledge of the array size, but - // determining array size potentially requires inlining and propagating. This suggests we would need semantic checking - // to happen iteratively with inlining and propagation, which we can't do now as we go from absy to typed_absy - let from = match from { - TypedExpression::FieldElement(FieldElementExpression::Number(n)) => Ok(n.to_dec_string().parse::().unwrap()), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected the lower bound of the range to be a constant field, found {}", - e - ), - }) - }?; + let to = r + .value + .to + .map(|e| self.check_expression(e, module_id, &types)) + .unwrap_or_else(|| Ok(array_size.clone().into()))?; - let to = match to { - TypedExpression::FieldElement(FieldElementExpression::Number(n)) => Ok(n.to_dec_string().parse::().unwrap()), - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Expected the higher bound of the range to be a constant field, found {}", - e - ), - }) - }?; + let from = UExpression::try_from_typed(from, UBitwidth::B32).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected the lower bound of the range to be a u32, found {} of type {}", + e, + e.get_type() + ), + })?; - match (from, to, array_size) { - (f, _, s) if f > s => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Lower range bound {} is out of array bounds [0, {}]", - f, s, - ), - }), - (_, t, s) if t > s => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Higher range bound {} is out of array bounds [0, {}]", - t, s, - ), - }), - (f, t, _) if f > t => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Lower range bound {} is larger than higher range bound {}", - f, t, - ), - }), - (f, t, _) => Ok(ArrayExpressionInner::Value( - (f..t) - .map(|i| match inner_type.clone() { - Type::FieldElement => FieldElementExpression::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Boolean => BooleanExpression::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .into(), - Type::Uint(bitwidth) => UExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(bitwidth) - .into(), - Type::Struct(struct_ty) => { - StructExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(struct_ty) - .into() - } - Type::Array(array_ty) => ArrayExpressionInner::Select( - box array.clone(), - box FieldElementExpression::Number(T::from(i)), - ) - .annotate(*array_ty.ty, array_ty.size) - .into(), - }) - .collect(), + let to = UExpression::try_from_typed(to, UBitwidth::B32).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected the upper bound of the range to be a u32, found {} of type {}", + e, + e.get_type() + ), + })?; + + Ok(ArrayExpressionInner::Slice( + box array, + box from.clone(), + box to.clone(), ) - .annotate(inner_type, t - f) - .into()), + .annotate(inner_type, UExpression::floor_sub(to, from)) + .into()) } - } - e => Err(ErrorInner { - pos: Some(pos), - message: format!( - "Cannot access slice of expression {} of type {}", - e, - e.get_type(), - ), - }), - }, - RangeOrExpression::Expression(e) => { - match (array, self.check_expression(e, module_id, &types)?) { - (TypedExpression::Array(a), TypedExpression::FieldElement(i)) => { - match a.inner_type().clone() { - Type::FieldElement => { - Ok(FieldElementExpression::select(a, i).into()) - } - Type::Uint(..) => Ok(UExpression::select(a, i).into()), - Type::Boolean => Ok(BooleanExpression::select(a, i).into()), - Type::Array(..) => Ok(ArrayExpression::select(a, i).into()), - Type::Struct(..) => Ok(StructExpression::select(a, i).into()), - } - } - (a, e) => Err(ErrorInner { + e => Err(ErrorInner { pos: Some(pos), message: format!( - "Cannot access element {} on expression of type {}", + "Cannot access slice of expression {} of type {}", e, + e.get_type(), + ), + }), + } + } + RangeOrExpression::Expression(index) => { + let index = self.check_expression(index, module_id, &types)?; + + let index = + UExpression::try_from_typed(index, UBitwidth::B32).map_err(|e| { + ErrorInner { + pos: Some(pos), + message: format!( + "Expected index to be of type u32, found {}", + e + ), + } + })?; + + match array { + TypedExpression::Array(a) => { + match a.inner_type().clone() { + Type::FieldElement => { + Ok(FieldElementExpression::select(a, index).into()) + } + Type::Uint(..) => Ok(UExpression::select(a, index).into()), + Type::Boolean => Ok(BooleanExpression::select(a, index).into()), + Type::Array(..) => Ok(ArrayExpression::select(a, index).into()), + Type::Struct(..) => Ok(StructExpression::select(a, index).into()), + Type::Int => unreachable!(), + } + } + a => Err(ErrorInner { + pos: Some(pos), + message: format!( + "Cannot access element as index {} of type {} on expression {} of type {}", + index, + index.get_type(), + a, a.get_type() ), }), @@ -1832,6 +2352,7 @@ impl<'ast> Checker<'ast> { match ty { Some(ty) => match ty { + Type::Int => unreachable!(), Type::FieldElement => { Ok(FieldElementExpression::member(s, id.to_string()).into()) } @@ -1850,7 +2371,17 @@ impl<'ast> Checker<'ast> { }, None => Err(ErrorInner { pos: Some(pos), - message: format!("{} doesn't have member {}", s.get_type(), id,), + message: format!( + "{} {{{}}} doesn't have member {}", + s.get_type(), + s.ty() + .members + .iter() + .map(|m| format!("{}: {}", m.id, m.ty)) + .collect::>() + .join(", "), + id, + ), }), } } @@ -1864,154 +2395,95 @@ impl<'ast> Checker<'ast> { }), } } - Expression::InlineArray(expressions) => { + Expression::InlineArray(expressions_or_spreads) => { // check each expression, getting its type - let mut expressions_checked = vec![]; - for e in expressions { + let mut expressions_or_spreads_checked = vec![]; + for e in expressions_or_spreads { let e_checked = self.check_spread_or_expression(e, module_id, &types)?; - expressions_checked.extend(e_checked); + expressions_or_spreads_checked.push(e_checked); } - // we infer the type to be the type of the first element - let inferred_type = expressions_checked.get(0).unwrap().get_type().clone(); - - match inferred_type { - Type::FieldElement => { - // we check all expressions have that same type - let mut unwrapped_expressions = vec![]; - - for e in expressions_checked { - let unwrapped_e = match e { - TypedExpression::FieldElement(e) => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected {} to have type {}, but type is {}", - e, - inferred_type, - e.get_type() - ), - }), - }?; - unwrapped_expressions.push(unwrapped_e.into()); - } - - let size = unwrapped_expressions.len(); - - Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(Type::FieldElement, size) - .into()) - } - Type::Boolean => { - // we check all expressions have that same type - let mut unwrapped_expressions = vec![]; - - for e in expressions_checked { - let unwrapped_e = match e { - TypedExpression::Boolean(e) => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected {} to have type {}, but type is {}", - e, - inferred_type, - e.get_type() - ), - }), - }?; - unwrapped_expressions.push(unwrapped_e.into()); - } - - let size = unwrapped_expressions.len(); - - Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(Type::Boolean, size) - .into()) - } - ty @ Type::Uint(..) => { - // we check all expressions have that same type - let mut unwrapped_expressions = vec![]; - - for e in expressions_checked { - let unwrapped_e = match e { - TypedExpression::Uint(e) if e.get_type() == ty => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected {} to have type {}, but type is {}", - e, - ty, - e.get_type() - ), - }), - }?; - unwrapped_expressions.push(unwrapped_e.into()); - } - - let size = unwrapped_expressions.len(); - - Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(ty, size) - .into()) - } - ty @ Type::Array(..) => { - // we check all expressions have that same type - let mut unwrapped_expressions = vec![]; - - for e in expressions_checked { - let unwrapped_e = match e { - TypedExpression::Array(e) if e.get_type() == ty => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected {} to have type {}, but type is {}", - e, - ty, - e.get_type() - ), - }), - }?; - unwrapped_expressions.push(unwrapped_e.into()); - } - - let size = unwrapped_expressions.len(); - - Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(ty, size) - .into()) - } - ty @ Type::Struct(..) => { - // we check all expressions have that same type - let mut unwrapped_expressions = vec![]; - - for e in expressions_checked { - let unwrapped_e = match e { - TypedExpression::Struct(e) if e.get_type() == ty => Ok(e), - e => Err(ErrorInner { - pos: Some(pos), - - message: format!( - "Expected {} to have type {}, but type is {}", - e, - ty, - e.get_type() - ), - }), - }?; - unwrapped_expressions.push(unwrapped_e.into()); - } - - let size = unwrapped_expressions.len(); - - Ok(ArrayExpressionInner::Value(unwrapped_expressions) - .annotate(ty, size) - .into()) - } + if expressions_or_spreads_checked.is_empty() { + return Err(ErrorInner { + pos: Some(pos), + message: "Empty arrays are not allowed".to_string(), + }); } + + // we infer the inner type to be the type of the first non-integer element + // if there was no such element, then the array only has integers and we use that as the inner type + let inferred_type = expressions_or_spreads_checked + .iter() + .filter_map(|e| match e.get_type() { + (Type::Int, _) => None, + (t, _) => Some(t), + }) + .next() + .unwrap_or(Type::Int); + + let unwrapped_expressions_or_spreads = match &inferred_type { + Type::Int => expressions_or_spreads_checked, + t => expressions_or_spreads_checked + .into_iter() + .map(|e| { + TypedExpressionOrSpread::align_to_type(e, t.clone()).map_err( + |(e, ty)| ErrorInner { + pos: Some(pos), + message: format!("Expected {} to have type {}", e, ty,), + }, + ) + }) + .collect::, _>>()?, + }; + + // the size of the inline array is the sum of the size of its elements. However expressed as a u32 expression, + // this value can be an tree of height n in the worst case, with n the size of the array (if all elements are + // simple values and not spreads, 1 + 1 + 1 + ... 1) + // To avoid that, we compute 2 sizes: the sum of all constant sizes as an u32 expression, and the + // sum of all non constant sizes as a u32 number. We then return the sum of the two as a u32 expression. + // `1 + 1 + ... + 1` is reduced to a single expression, which prevents this blowup + + let size: UExpression<'ast, T> = unwrapped_expressions_or_spreads + .iter() + .map(|e| e.size()) + .fold(None, |acc, e| match acc { + Some((c_acc, e_acc)) => match e.as_inner() { + UExpressionInner::Value(e) => Some(((c_acc + *e as u32), e_acc)), + _ => Some((c_acc, e_acc + e)), + }, + None => match e.as_inner() { + UExpressionInner::Value(e) => Some((*e as u32, 0u32.into())), + _ => Some((0u32, e)), + }, + }) + .map(|(c_size, e_size)| e_size + c_size.into()) + .unwrap_or_else(|| 0u32.into()); + + Ok( + ArrayExpressionInner::Value(unwrapped_expressions_or_spreads.into()) + .annotate(inferred_type, size) + .into(), + ) + } + Expression::ArrayInitializer(box e, box count) => { + let e = self.check_expression(e, module_id, &types)?; + let ty = e.get_type(); + + let count = self.check_expression(count, module_id, &types)?; + + let count = + UExpression::try_from_typed(count, UBitwidth::B32).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected array initializer count to be a u32, found {} of type {}", + e, + e.get_type(), + ), + })?; + + Ok(ArrayExpressionInner::Repeat(box e, box count.clone()) + .annotate(ty, count) + .into()) } Expression::InlineStruct(id, inline_members) => { let ty = self.check_type( @@ -2026,13 +2498,19 @@ impl<'ast> Checker<'ast> { // check that we provided the required number of values - if struct_type.len() != inline_members.len() { + if struct_type.members_count() != inline_members.len() { return Err(ErrorInner { pos: Some(pos), message: format!( - "Inline struct {} does not match {}", - Expression::InlineStruct(id.clone(), inline_members), - Type::Struct(struct_type) + "Inline struct {} does not match {} {{{}}}", + Expression::InlineStruct(id, inline_members), + Type::Struct(struct_type.clone()), + struct_type + .members + .iter() + .map(|m| format!("{}: {}", m.id, m.ty)) + .collect::>() + .join(", ") ), }); } @@ -2052,31 +2530,39 @@ impl<'ast> Checker<'ast> { Some(value) => { let expression_checked = self.check_expression(value, module_id, &types)?; - let checked_type = expression_checked.get_type(); - if checked_type != *member.ty { - return Err(ErrorInner { - pos: Some(pos), - message: format!( - "Member {} of struct {} has type {}, found {} of type {}", - member.id, - id.clone(), - member.ty, - expression_checked, - checked_type, - ), - }); - } else { - result.push(expression_checked.into()); - } + + let expression_checked = TypedExpression::align_to_type( + expression_checked, + *member.ty.clone(), + ) + .map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Member {} of struct {} has type {}, found {} of type {}", + member.id, + id.clone(), + e.1, + e.0, + e.0.get_type(), + ), + })?; + + result.push(expression_checked); } None => { return Err(ErrorInner { pos: Some(pos), message: format!( - "Member {} of struct {} not found in value {}", + "Member {} of struct {} {{{}}} not found in value {}", member.id, Type::Struct(struct_type.clone()), - Expression::InlineStruct(id.clone(), inline_members), + struct_type + .members + .iter() + .map(|m| format!("{}: {}", m.id, m.ty)) + .collect::>() + .join(", "), + Expression::InlineStruct(id, inline_members), ), }) } @@ -2090,7 +2576,23 @@ impl<'ast> Checker<'ast> { Expression::And(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!( + "Cannot apply boolean operators to {} and {}", + e1.get_type(), + e2.get_type() + ), + })?; + match (e1_checked, e2_checked) { + (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { + Ok(IntExpression::And(box e1, box e2).into()) + } (TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => { Ok(BooleanExpression::And(box e1, box e2).into()) } @@ -2098,7 +2600,7 @@ impl<'ast> Checker<'ast> { pos: Some(pos), message: format!( - "cannot apply boolean operators to {} and {}", + "Cannot apply boolean operators to {} and {}", e1.get_type(), e2.get_type() ), @@ -2114,25 +2616,35 @@ impl<'ast> Checker<'ast> { } (e1, e2) => Err(ErrorInner { pos: Some(pos), - - message: format!("cannot compare {} to {}", e1.get_type(), e2.get_type()), + message: format!( + "Cannot apply `||` to {}, {}", + e1.get_type(), + e2.get_type() + ), }), } } Expression::LeftShift(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; - match (e1_checked, e2_checked) { - (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) - if e2.bitwidth == UBitwidth::B32 => - { - Ok(UExpression::left_shift(e1, e2).into()) - } - (e1, e2) => Err(ErrorInner { + let e1 = self.check_expression(e1, module_id, &types)?; + let e2 = self.check_expression(e2, module_id, &types)?; + + let e2 = + UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected the left shift right operand to have type `u32`, found {}", + e + ), + })?; + + match e1 { + TypedExpression::Int(e1) => Ok(IntExpression::LeftShift(box e1, box e2).into()), + TypedExpression::Uint(e1) => Ok(UExpression::left_shift(e1, e2).into()), + e1 => Err(ErrorInner { pos: Some(pos), message: format!( - "cannot left-shift {} by {}", + "Cannot left-shift {} by {}", e1.get_type(), e2.get_type() ), @@ -2140,19 +2652,28 @@ impl<'ast> Checker<'ast> { } } Expression::RightShift(box e1, box e2) => { - let e1_checked = self.check_expression(e1, module_id, &types)?; - let e2_checked = self.check_expression(e2, module_id, &types)?; - match (e1_checked, e2_checked) { - (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) - if e2.bitwidth == UBitwidth::B32 => - { - Ok(UExpression::right_shift(e1, e2).into()) + let e1 = self.check_expression(e1, module_id, &types)?; + let e2 = self.check_expression(e2, module_id, &types)?; + + let e2 = + UExpression::try_from_typed(e2, UBitwidth::B32).map_err(|e| ErrorInner { + pos: Some(pos), + message: format!( + "Expected the right shift right operand to be of type `u32`, found {}", + e + ), + })?; + + match e1 { + TypedExpression::Int(e1) => { + Ok(IntExpression::RightShift(box e1, box e2).into()) } - (e1, e2) => Err(ErrorInner { + TypedExpression::Uint(e1) => Ok(UExpression::right_shift(e1, e2).into()), + e1 => Err(ErrorInner { pos: Some(pos), message: format!( - "cannot right-shift {} by {}", + "Cannot right-shift {} by {}", e1.get_type(), e2.get_type() ), @@ -2162,7 +2683,19 @@ impl<'ast> Checker<'ast> { Expression::BitOr(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `|` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { + (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { + Ok(IntExpression::Or(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -2182,7 +2715,19 @@ impl<'ast> Checker<'ast> { Expression::BitAnd(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `&` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { + (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { + Ok(IntExpression::And(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -2202,7 +2747,19 @@ impl<'ast> Checker<'ast> { Expression::BitXor(box e1, box e2) => { let e1_checked = self.check_expression(e1, module_id, &types)?; let e2_checked = self.check_expression(e2, module_id, &types)?; + + let (e1_checked, e2_checked) = TypedExpression::align_without_integers( + e1_checked, e2_checked, + ) + .map_err(|(e1, e2)| ErrorInner { + pos: Some(pos), + message: format!("Cannot apply `^` to {}, {}", e1.get_type(), e2.get_type()), + })?; + match (e1_checked, e2_checked) { + (TypedExpression::Int(e1), TypedExpression::Int(e2)) => { + Ok(IntExpression::Xor(box e1, box e2).into()) + } (TypedExpression::Uint(e1), TypedExpression::Uint(e2)) if e1.bitwidth() == e2.bitwidth() => { @@ -2222,19 +2779,21 @@ impl<'ast> Checker<'ast> { Expression::Not(box e) => { let e_checked = self.check_expression(e, module_id, &types)?; match e_checked { + TypedExpression::Int(e) => Ok(IntExpression::Not(box e).into()), TypedExpression::Boolean(e) => Ok(BooleanExpression::Not(box e).into()), - TypedExpression::Uint(e) => Ok(UExpression::not(e).into()), + TypedExpression::Uint(e) => Ok((!e).into()), e => Err(ErrorInner { pos: Some(pos), - - message: format!("cannot negate {}", e.get_type()), + message: format!("Cannot negate {}", e.get_type()), }), } } } } - fn get_scope(&self, variable_name: &'ast str) -> Option<&'ast ScopedVariable> { + fn get_scope<'a>(&'a self, variable_name: &'ast str) -> Option<&'a ScopedVariable<'ast, T>> { + // we take advantage of the fact that all ScopedVariable of the same identifier hash to the same thing + // and by construction only one can be in the set self.scope.get(&ScopedVariable { id: Variable::with_id_and_type( crate::typed_absy::Identifier::from(variable_name), @@ -2244,14 +2803,14 @@ impl<'ast> Checker<'ast> { }) } - fn insert_into_scope(&mut self, v: Variable<'ast>) -> bool { + fn insert_into_scope>>(&mut self, v: U) -> bool { self.scope.insert(ScopedVariable { - id: v, + id: v.into(), level: self.level, }) } - fn find_function(&self, query: &FunctionQuery<'ast>) -> Option> { + fn find_functions(&self, query: &FunctionQuery<'ast, T>) -> Vec> { query.match_funcs(&self.functions) } @@ -2272,7 +2831,6 @@ mod tests { use super::*; use crate::absy; use crate::typed_absy; - use num_bigint::BigUint; use zokrates_field::Bn128Field; const MODULE_ID: &str = ""; @@ -2284,27 +2842,31 @@ mod tests { #[test] fn field_in_range() { + // The value of `P - 1` is a valid field literal + let types = HashMap::new(); let module_id = "".into(); let expr = Expression::FieldConstant(BigUint::from(Bn128Field::max_value().to_biguint())) .mock(); - assert!(Checker::new() - .check_expression::(expr, &module_id, &types) + assert!(Checker::::new() + .check_expression(expr, &module_id, &types) .is_ok()); } #[test] fn field_overflow() { + // the value of `P` is an invalid field literal + let types = HashMap::new(); let module_id = "".into(); let value = Bn128Field::max_value().to_biguint().add(1u32); - let expr = Expression::FieldConstant(BigUint::from(value)).mock(); + let expr = Expression::FieldConstant(value).mock(); - assert!(Checker::new() - .check_expression::(expr, &module_id, &types) + assert!(Checker::::new() + .check_expression(expr, &module_id, &types) .is_err()); } } @@ -2315,19 +2877,24 @@ mod tests { #[test] fn element_type_mismatch() { + // having different types in an array isn't allowed + // in the case of arrays, lengths do *not* have to match, as at this point they can be + // generic, so we cannot tell yet + let types = HashMap::new(); let module_id = "".into(); // [3, true] let a = Expression::InlineArray(vec![ - Expression::FieldConstant(BigUint::from(3u32)).mock().into(), + Expression::IntConstant(3usize.into()).mock().into(), Expression::BooleanConstant(true).mock().into(), ]) .mock(); - assert!(Checker::new() - .check_expression::(a, &module_id, &types) + assert!(Checker::::new() + .check_expression(a, &module_id, &types) .is_err()); - // [[0], [0, 0]] + // [[0f], [0f, 0f]] + // accepted at this stage, as we do not check array lengths (as they can be variable) let a = Expression::InlineArray(vec![ Expression::InlineArray(vec![Expression::FieldConstant(BigUint::from(0u32)) .mock() @@ -2342,11 +2909,11 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::new() - .check_expression::(a, &module_id, &types) - .is_err()); + assert!(Checker::::new() + .check_expression(a, &module_id, &types) + .is_ok()); - // [[0], true] + // [[0f], true] let a = Expression::InlineArray(vec![ Expression::InlineArray(vec![Expression::FieldConstant(BigUint::from(0u32)) .mock() @@ -2358,64 +2925,64 @@ mod tests { .into(), ]) .mock(); - assert!(Checker::new() - .check_expression::(a, &module_id, &types) + assert!(Checker::::new() + .check_expression(a, &module_id, &types) .is_err()); } } + /// Helper function to create ((): return) + fn function0() -> FunctionNode<'static> { + let statements = vec![Statement::Return( + ExpressionList { + expressions: vec![], + } + .mock(), + ) + .mock()]; + + let arguments = vec![]; + + let signature = UnresolvedSignature::new(); + + Function { + arguments, + statements, + signature, + } + .mock() + } + + /// Helper function to create ((private field a): return) + fn function1() -> FunctionNode<'static> { + let statements = vec![Statement::Return( + ExpressionList { + expressions: vec![], + } + .mock(), + ) + .mock()]; + + let arguments = vec![absy::Parameter { + id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + private: true, + } + .mock()]; + + let signature = + UnresolvedSignature::new().inputs(vec![UnresolvedType::FieldElement.mock()]); + + Function { + arguments, + statements, + signature, + } + .mock() + } + mod symbols { use super::*; - /// Helper function to create ((): return) - fn function0() -> FunctionNode<'static> { - let statements: Vec = vec![Statement::Return( - ExpressionList { - expressions: vec![], - } - .mock(), - ) - .mock()]; - - let arguments = vec![]; - - let signature = UnresolvedSignature::new(); - - Function { - arguments, - statements, - signature, - } - .mock() - } - - /// Helper function to create ((private field a): return) - fn function1() -> FunctionNode<'static> { - let statements: Vec = vec![Statement::Return( - ExpressionList { - expressions: vec![], - } - .mock(), - ) - .mock()]; - - let arguments = vec![absy::Parameter { - id: absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), - private: true, - } - .mock()]; - - let signature = - UnresolvedSignature::new().inputs(vec![UnresolvedType::FieldElement.mock()]); - - Function { - arguments, - statements, - signature, - } - .mock() - } - fn struct0() -> StructDefinitionNode<'static> { StructDefinition { fields: vec![] }.mock() } @@ -2423,7 +2990,7 @@ mod tests { fn struct1() -> StructDefinitionNode<'static> { StructDefinition { fields: vec![StructDefinitionField { - id: "foo".into(), + id: "foo", ty: UnresolvedType::FieldElement.mock(), } .mock()], @@ -2437,14 +3004,47 @@ mod tests { let mut unifier = SymbolUnifier::default(); + // the `foo` type assert!(unifier.insert_type("foo")); + // the `foo` type annot be declared a second time assert!(!unifier.insert_type("foo")); - assert!(!unifier.insert_function("foo", Signature::new())); - assert!(unifier.insert_function("bar", Signature::new())); - assert!(!unifier.insert_function("bar", Signature::new())); - assert!( - unifier.insert_function("bar", Signature::new().inputs(vec![Type::FieldElement])) - ); + // the `foo` function cannot be declared as the name is already taken by a type + assert!(!unifier.insert_function("foo", DeclarationSignature::new())); + // the `bar` type + assert!(unifier.insert_function("bar", DeclarationSignature::new())); + // a second `bar` function of the same signature cannot be declared + assert!(!unifier.insert_function("bar", DeclarationSignature::new())); + // a second `bar` function of a different signature can be declared + assert!(unifier.insert_function( + "bar", + DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) + )); + // a second `bar` function with a generic parameter, which *could* conflict with an existing one should not be allowed + assert!(!unifier.insert_function( + "bar", + DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::FieldElement]) + )); + // a `bar` function with a different signature + assert!(unifier.insert_function( + "bar", + DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + "K" + ))]) + )); + // a `bar` function with a different signature, but which could conflict with the previous one + assert!(!unifier.insert_function( + "bar", + DeclarationSignature::new().inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + 42u32 + ))]) + )); + // a `bar` type isn't allowed as the name is already taken by at least one function assert!(!unifier.insert_type("bar")); } @@ -2483,17 +3083,18 @@ mod tests { .collect(), ); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!(checker.check_module(&"bar".into(), &mut state), Ok(())); assert_eq!( state.typed_modules.get(&PathBuf::from("bar")), Some(&TypedModule { functions: vec![( - FunctionKey::with_id("main").signature(Signature::new()), + DeclarationFunctionKey::with_location("bar", "main") + .signature(DeclarationSignature::new()), TypedFunctionSymbol::There( - FunctionKey::with_id("main").signature(Signature::new()), - "foo".into() + DeclarationFunctionKey::with_location("foo", "main") + .signature(DeclarationSignature::new()), ) )] .into_iter() @@ -2533,7 +3134,7 @@ mod tests { .collect(), ); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker .check_module(&PathBuf::from(MODULE_ID).into(), &mut state) @@ -2544,6 +3145,197 @@ mod tests { ); } + #[test] + fn duplicate_function_declaration_generic() { + // def foo

(private field[P] a): + // return + // def foo(private field[3] a): + // return + // + // should fail as P could be equal to 3 + + let mut f0 = function0(); + + f0.value.arguments = vec![absy::Parameter::private( + absy::Variable::new( + "a", + UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::Identifier("P").mock(), + ) + .mock(), + ) + .mock(), + ) + .mock()]; + f0.value.signature = UnresolvedSignature::new() + .generics(vec!["P".mock()]) + .inputs(vec![UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::Identifier("P").mock(), + ) + .mock()]); + + let mut f1 = function0(); + f1.value.arguments = vec![absy::Parameter::private( + absy::Variable::new( + "a", + UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::U32Constant(3).mock(), + ) + .mock(), + ) + .mock(), + ) + .mock()]; + f1.value.signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::U32Constant(3).mock(), + ) + .mock()]); + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(f0), + } + .mock(), + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(f1), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new( + vec![(PathBuf::from(MODULE_ID).into(), module)] + .into_iter() + .collect(), + ); + + let mut checker: Checker = Checker::new(); + assert_eq!( + checker + .check_module(&PathBuf::from(MODULE_ID), &mut state) + .unwrap_err()[0] + .inner + .message, + "foo conflicts with another symbol" + ); + } + + mod generics { + use super::*; + + #[test] + fn unused_generic() { + // def foo

(): + // return + // def main(): + // return + // + // should succeed + + let mut foo = function0(); + + foo.value.signature = UnresolvedSignature::new().generics(vec!["P".mock()]); + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(foo), + } + .mock(), + SymbolDeclaration { + id: "main", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new( + vec![(PathBuf::from(MODULE_ID).into(), module)] + .into_iter() + .collect(), + ); + + let mut checker: Checker = Checker::new(); + assert!(checker + .check_module(&PathBuf::from(MODULE_ID).into(), &mut state) + .is_ok()); + } + + #[test] + fn undeclared_generic() { + // def foo(field[P] a): + // return + // def main(): + // return + // + // should fail + + let mut foo = function0(); + + foo.value.arguments = vec![absy::Parameter::private( + absy::Variable::new( + "a", + UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::Identifier("P").mock(), + ) + .mock(), + ) + .mock(), + ) + .mock()]; + foo.value.signature = + UnresolvedSignature::new().inputs(vec![UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::Identifier("P").mock(), + ) + .mock()]); + + let module = Module { + symbols: vec![ + SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(foo), + } + .mock(), + SymbolDeclaration { + id: "main", + symbol: Symbol::HereFunction(function0()), + } + .mock(), + ], + imports: vec![], + }; + + let mut state = State::new( + vec![(PathBuf::from(MODULE_ID).into(), module)] + .into_iter() + .collect(), + ); + + let mut checker: Checker = Checker::new(); + assert_eq!( + checker + .check_module(&PathBuf::from(MODULE_ID).into(), &mut state) + .unwrap_err()[0] + .inner + .message, + "Undeclared generic parameter in function definition: `P` isn\'t declared as a generic constant" + ); + } + } + #[test] fn overloaded_function_declaration() { // def foo(): @@ -2575,7 +3367,7 @@ mod tests { .collect(), ); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker.check_module(&PathBuf::from(MODULE_ID), &mut state), Ok(()) @@ -2585,15 +3377,19 @@ mod tests { .get(&PathBuf::from(MODULE_ID)) .unwrap() .functions - .contains_key(&FunctionKey::with_id("foo").signature(Signature::new()))); + .contains_key( + &DeclarationFunctionKey::with_location(MODULE_ID, "foo") + .signature(DeclarationSignature::new()) + )); assert!(state .typed_modules .get(&PathBuf::from(MODULE_ID)) .unwrap() .functions .contains_key( - &FunctionKey::with_id("foo") - .signature(Signature::new().inputs(vec![Type::FieldElement])) + &DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature( + DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]) + ) )) } @@ -2623,7 +3419,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker .check_module(&"main".into(), &mut state) @@ -2661,7 +3457,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker .check_module(&"main".into(), &mut state) @@ -2715,7 +3511,7 @@ mod tests { .collect(), ); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker .check_module(&PathBuf::from(MODULE_ID), &mut state) @@ -2766,7 +3562,7 @@ mod tests { .collect(), ); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker .check_module(&PathBuf::from(MODULE_ID), &mut state) @@ -2778,15 +3574,76 @@ mod tests { } } - pub fn new_with_args<'ast>( - scope: HashSet>, + pub fn new_with_args<'ast, T: Field>( + scope: HashSet>, level: usize, - functions: HashSet>, - ) -> Checker<'ast> { + functions: HashSet>, + ) -> Checker<'ast, T> { Checker { scope, functions, level, + return_types: None, + } + } + + // checking function signatures + mod signature { + use super::*; + + #[test] + fn undeclared_generic() { + let signature = UnresolvedSignature::new().inputs(vec![UnresolvedType::Array( + box UnresolvedType::FieldElement.mock(), + Expression::Identifier("K").mock(), + ) + .mock()]); + assert_eq!(Checker::::new().check_signature(signature, &MODULE_ID.into(), &TypeMap::default()), Err(vec![ErrorInner { + pos: Some((Position::mock(), Position::mock())), + message: "Undeclared generic parameter in function definition: `K` isn\'t declared as a generic constant".to_string() + }])); + } + + #[test] + fn success() { + // (field[L][K]) -> field[L][K] + + let signature = UnresolvedSignature::new() + .generics(vec!["K".mock(), "L".mock(), "M".mock()]) + .inputs(vec![UnresolvedType::Array( + box UnresolvedType::Array( + box UnresolvedType::FieldElement.mock(), + Expression::Identifier("K").mock(), + ) + .mock(), + Expression::Identifier("L").mock(), + ) + .mock()]) + .outputs(vec![UnresolvedType::Array( + box UnresolvedType::Array( + box UnresolvedType::FieldElement.mock(), + Expression::Identifier("L").mock(), + ) + .mock(), + Expression::Identifier("K").mock(), + ) + .mock()]); + assert_eq!( + Checker::::new().check_signature( + signature, + &MODULE_ID.into(), + &TypeMap::default() + ), + Ok(DeclarationSignature::new() + .inputs(vec![DeclarationType::array(( + DeclarationType::array((DeclarationType::FieldElement, "K")), + "L" + ))]) + .outputs(vec![DeclarationType::array(( + DeclarationType::array((DeclarationType::FieldElement, "L")), + "K" + ))])) + ); } } @@ -2803,9 +3660,9 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_statement::(statement, &module_id, &types), + checker.check_statement(statement, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"b\" is undefined".into() @@ -2835,9 +3692,9 @@ mod tests { id: Variable::field_element("b"), level: 0, }); - let mut checker = new_with_args(scope, 1, HashSet::new()); + let mut checker: Checker = new_with_args(scope, 1, HashSet::new()); assert_eq!( - checker.check_statement::(statement, &module_id, &types), + checker.check_statement(statement, &module_id, &types), Ok(TypedStatement::Definition( TypedAssignee::Identifier(typed_absy::Variable::field_element("a")), FieldElementExpression::Identifier("b".into()).into() @@ -2861,7 +3718,7 @@ mod tests { .mock(), Statement::Definition( Assignee::Identifier("a").mock(), - Expression::FieldConstant(BigUint::from(1u32)).mock(), + Expression::IntConstant(1usize.into()).mock(), ) .mock(), Statement::Return( @@ -2891,10 +3748,9 @@ mod tests { let bar = Function { arguments: bar_args, statements: bar_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); @@ -2918,7 +3774,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( checker.check_module(&"main".into(), &mut state), Err(vec![Error { @@ -2950,7 +3806,7 @@ mod tests { .mock(), Statement::Definition( Assignee::Identifier("a").mock(), - Expression::FieldConstant(BigUint::from(1u32)).mock(), + Expression::IntConstant(1usize.into()).mock(), ) .mock(), Statement::Return( @@ -2977,7 +3833,7 @@ mod tests { .mock(), Statement::Definition( Assignee::Identifier("a").mock(), - Expression::FieldConstant(BigUint::from(2u32)).mock(), + Expression::IntConstant(2usize.into()).mock(), ) .mock(), Statement::Return( @@ -2998,7 +3854,7 @@ mod tests { let main_args = vec![]; let main_statements = vec![Statement::Return( ExpressionList { - expressions: vec![Expression::FieldConstant(BigUint::from(1u32)).mock()], + expressions: vec![Expression::IntConstant(1usize.into()).mock()], } .mock(), ) @@ -3007,10 +3863,9 @@ mod tests { let main = Function { arguments: main_args, statements: main_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); @@ -3039,7 +3894,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert!(checker.check_module(&"main".into(), &mut state).is_ok()); } @@ -3050,11 +3905,11 @@ mod tests { // endfor // return i // should fail - let foo_statements = vec![ + let foo_statements: Vec = vec![ Statement::For( - absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), - Expression::FieldConstant(BigUint::from(0u32)).mock(), - Expression::FieldConstant(BigUint::from(10u32)).mock(), + absy::Variable::new("i", UnresolvedType::Uint(32).mock()).mock(), + Expression::IntConstant(0usize.into()).mock(), + Expression::IntConstant(10usize.into()).mock(), vec![], ) .mock(), @@ -3069,19 +3924,18 @@ mod tests { let foo = Function { arguments: vec![], statements: foo_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); let types = HashMap::new(); let module_id = "".into(); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function::(foo, &module_id, &types), + checker.check_function(foo, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"i\" is undefined".into() @@ -3092,7 +3946,7 @@ mod tests { #[test] fn for_index_in_for() { // def foo(): - // for i in 0..10 do + // for field i in 0..10 do // a = i // endfor // return @@ -3100,7 +3954,7 @@ mod tests { let for_statements = vec![ Statement::Declaration( - absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + absy::Variable::new("a", UnresolvedType::Uint(32).mock()).mock(), ) .mock(), Statement::Definition( @@ -3112,9 +3966,9 @@ mod tests { let foo_statements = vec![ Statement::For( - absy::Variable::new("i", UnresolvedType::FieldElement.mock()).mock(), - Expression::FieldConstant(BigUint::from(0u32)).mock(), - Expression::FieldConstant(BigUint::from(10u32)).mock(), + absy::Variable::new("i", UnresolvedType::Uint(32).mock()).mock(), + Expression::IntConstant(0usize.into()).mock(), + Expression::IntConstant(10usize.into()).mock(), for_statements, ) .mock(), @@ -3128,18 +3982,20 @@ mod tests { ]; let for_statements_checked = vec![ - TypedStatement::Declaration(typed_absy::Variable::field_element("a")), + TypedStatement::Declaration(typed_absy::Variable::uint("a", UBitwidth::B32)), TypedStatement::Definition( - TypedAssignee::Identifier(typed_absy::Variable::field_element("a")), - FieldElementExpression::Identifier("i".into()).into(), + TypedAssignee::Identifier(typed_absy::Variable::uint("a", UBitwidth::B32)), + UExpressionInner::Identifier("i".into()) + .annotate(UBitwidth::B32) + .into(), ), ]; let foo_statements_checked = vec![ TypedStatement::For( - typed_absy::Variable::field_element("i"), - FieldElementExpression::Number(Bn128Field::from(0u32)), - FieldElementExpression::Number(Bn128Field::from(10u32)), + typed_absy::Variable::uint("i", UBitwidth::B32), + 0u32.into(), + 10u32.into(), for_statements_checked, ), TypedStatement::Return(vec![]), @@ -3155,15 +4011,15 @@ mod tests { let foo_checked = TypedFunction { arguments: vec![], statements: foo_statements_checked, - signature: Signature::new(), + signature: DeclarationSignature::default(), }; let types = HashMap::new(); let module_id = "".into(); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_function::(foo, &module_id, &types), + checker.check_function(foo, &module_id, &types), Ok(foo_checked) ); } @@ -3182,7 +4038,7 @@ mod tests { .mock(), Statement::MultipleDefinition( vec![Assignee::Identifier("a").mock()], - Expression::FunctionCall("foo", vec![]).mock(), + Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), Statement::Return( @@ -3194,12 +4050,13 @@ mod tests { .mock(), ]; - let foo = FunctionKey { + let foo = DeclarationFunctionKey { + module: "main".into(), id: "foo", - signature: Signature { - inputs: vec![], - outputs: vec![Type::FieldElement, Type::FieldElement], - }, + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![ + DeclarationType::FieldElement, + DeclarationType::FieldElement, + ]), }; let functions = vec![foo].into_iter().collect(); @@ -3214,9 +4071,9 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function::(bar, &module_id, &types), + checker.check_function(bar, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: @@ -3237,8 +4094,8 @@ mod tests { let bar_statements: Vec = vec![ Statement::Assertion( Expression::Eq( - box Expression::FieldConstant(BigUint::from(2u32)).mock(), - box Expression::FunctionCall("foo", vec![]).mock(), + box Expression::IntConstant(2usize.into()).mock(), + box Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), ) @@ -3252,12 +4109,13 @@ mod tests { .mock(), ]; - let foo = FunctionKey { + let foo = DeclarationFunctionKey { + module: "main".into(), id: "foo", - signature: Signature { - inputs: vec![], - outputs: vec![Type::FieldElement, Type::FieldElement], - }, + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![ + DeclarationType::FieldElement, + DeclarationType::FieldElement, + ]), }; let functions = vec![foo].into_iter().collect(); @@ -3265,19 +4123,16 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![], - }, + signature: UnresolvedSignature::new(), } .mock(); let types = HashMap::new(); let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function::(bar, &module_id, &types), + checker.check_function(bar, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Function definition for function foo with signature () -> _ not found." @@ -3299,7 +4154,7 @@ mod tests { .mock(), Statement::MultipleDefinition( vec![Assignee::Identifier("a").mock()], - Expression::FunctionCall("foo", vec![]).mock(), + Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), Statement::Return( @@ -3321,9 +4176,9 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function::(bar, &module_id, &types), + checker.check_function(bar, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -3346,8 +4201,8 @@ mod tests { let foo_statements: Vec = vec![Statement::Return( ExpressionList { expressions: vec![ - Expression::FieldConstant(BigUint::from(1u32)).mock(), - Expression::FieldConstant(BigUint::from(2u32)).mock(), + Expression::IntConstant(1usize.into()).mock(), + Expression::IntConstant(2usize.into()).mock(), ], } .mock(), @@ -3361,13 +4216,12 @@ mod tests { } .mock()], statements: foo_statements, - signature: UnresolvedSignature { - inputs: vec![UnresolvedType::FieldElement.mock()], - outputs: vec![ + signature: UnresolvedSignature::new() + .inputs(vec![UnresolvedType::FieldElement.mock()]) + .outputs(vec![ UnresolvedType::FieldElement.mock(), UnresolvedType::FieldElement.mock(), - ], - }, + ]), } .mock(); @@ -3385,12 +4239,13 @@ mod tests { Assignee::Identifier("a").mock(), Assignee::Identifier("b").mock(), ], - Expression::FunctionCall("foo", vec![Expression::Identifier("x").mock()]).mock(), + Expression::FunctionCall("foo", None, vec![Expression::Identifier("x").mock()]) + .mock(), ) .mock(), Statement::Return( ExpressionList { - expressions: vec![Expression::FieldConstant(BigUint::from(1u32)).mock()], + expressions: vec![Expression::IntConstant(1usize.into()).mock()], } .mock(), ) @@ -3400,10 +4255,9 @@ mod tests { let main = Function { arguments: vec![], statements: main_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); @@ -3426,7 +4280,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( checker.check_module(&"main".into(), &mut state), Err(vec![Error { @@ -3451,8 +4305,8 @@ mod tests { let foo_statements: Vec = vec![Statement::Return( ExpressionList { expressions: vec![ - Expression::FieldConstant(BigUint::from(1u32)).mock(), - Expression::FieldConstant(BigUint::from(2u32)).mock(), + Expression::IntConstant(1usize.into()).mock(), + Expression::IntConstant(2usize.into()).mock(), ], } .mock(), @@ -3462,13 +4316,10 @@ mod tests { let foo = Function { arguments: vec![], statements: foo_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![ - UnresolvedType::FieldElement.mock(), - UnresolvedType::FieldElement.mock(), - ], - }, + signature: UnresolvedSignature::new().inputs(vec![]).outputs(vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ]), } .mock(); @@ -3478,7 +4329,7 @@ mod tests { Assignee::Identifier("a").mock(), Assignee::Identifier("b").mock(), ], - Expression::FunctionCall("foo", vec![]).mock(), + Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), Statement::Return( @@ -3493,10 +4344,7 @@ mod tests { let main = Function { arguments: vec![], statements: main_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![], - }, + signature: UnresolvedSignature::new().inputs(vec![]).outputs(vec![]), } .mock(); @@ -3519,7 +4367,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( checker.check_module(&"main".into(), &mut state), Err(vec![ @@ -3553,7 +4401,7 @@ mod tests { let foo_statements: Vec = vec![Statement::Return( ExpressionList { - expressions: vec![Expression::FieldConstant(BigUint::from(1u32)).mock()], + expressions: vec![Expression::IntConstant(1usize.into()).mock()], } .mock(), ) @@ -3562,10 +4410,9 @@ mod tests { let foo = Function { arguments: vec![], statements: foo_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); @@ -3573,15 +4420,19 @@ mod tests { Statement::Declaration( absy::Variable::new( "a", - UnresolvedType::array(UnresolvedType::FieldElement.mock(), 1).mock(), + UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::IntConstant(1usize.into()).mock(), + ) + .mock(), ) .mock(), ) .mock(), Statement::Definition( - Assignee::Identifier("a".into()).mock(), + Assignee::Identifier("a").mock(), Expression::InlineArray(vec![absy::SpreadOrExpression::Expression( - Expression::FieldConstant(BigUint::from(0u32)).mock(), + Expression::IntConstant(0usize.into()).mock(), )]) .mock(), ) @@ -3590,11 +4441,11 @@ mod tests { vec![Assignee::Select( box Assignee::Identifier("a").mock(), box RangeOrExpression::Expression( - absy::Expression::FieldConstant(BigUint::from(0u32)).mock(), + absy::Expression::IntConstant(0usize.into()).mock(), ), ) .mock()], - Expression::FunctionCall("foo", vec![]).mock(), + Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), Statement::Return( @@ -3609,10 +4460,7 @@ mod tests { let main = Function { arguments: vec![], statements: main_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![], - }, + signature: UnresolvedSignature::new().inputs(vec![]).outputs(vec![]), } .mock(); @@ -3635,7 +4483,7 @@ mod tests { let mut state = State::::new(vec![("main".into(), module)].into_iter().collect()); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert!(checker.check_module(&"main".into(), &mut state).is_ok()); } @@ -3645,11 +4493,12 @@ mod tests { // 1 == foo() // return // should fail + let bar_statements: Vec = vec![ Statement::Assertion( Expression::Eq( - box Expression::FieldConstant(BigUint::from(1u32)).mock(), - box Expression::FunctionCall("foo", vec![]).mock(), + box Expression::IntConstant(1usize.into()).mock(), + box Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), ) @@ -3673,9 +4522,9 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function::(bar, &module_id, &types), + checker.check_function(bar, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), @@ -3704,22 +4553,19 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![ - UnresolvedType::FieldElement.mock(), - UnresolvedType::FieldElement.mock(), - ], - }, + signature: UnresolvedSignature::new().inputs(vec![]).outputs(vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::FieldElement.mock(), + ]), } .mock(); let types = HashMap::new(); let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, HashSet::new()); + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); assert_eq!( - checker.check_function::(bar, &module_id, &types), + checker.check_function(bar, &module_id, &types), Err(vec![ErrorInner { pos: Some((Position::mock(), Position::mock())), message: "Identifier \"a\" is undefined".into() @@ -3750,7 +4596,7 @@ mod tests { Assignee::Identifier("a").mock(), Assignee::Identifier("b").mock(), ], - Expression::FunctionCall("foo", vec![]).mock(), + Expression::FunctionCall("foo", None, vec![]).mock(), ) .mock(), Statement::Return( @@ -3775,10 +4621,14 @@ mod tests { typed_absy::Variable::field_element("b").into(), ], TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new().outputs(vec![Type::FieldElement, Type::FieldElement]), + DeclarationFunctionKey::with_location(MODULE_ID, "foo").signature( + DeclarationSignature::new().outputs(vec![ + DeclarationType::FieldElement, + DeclarationType::FieldElement, + ]), ), vec![], + vec![], vec![Type::FieldElement, Type::FieldElement], ), ), @@ -3789,12 +4639,13 @@ mod tests { .into()]), ]; - let foo = FunctionKey { + let foo = DeclarationFunctionKey { + module: MODULE_ID.into(), id: "foo", - signature: Signature { - inputs: vec![], - outputs: vec![Type::FieldElement, Type::FieldElement], - }, + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![ + DeclarationType::FieldElement, + DeclarationType::FieldElement, + ]), }; let mut functions = HashSet::new(); @@ -3803,32 +4654,62 @@ mod tests { let bar = Function { arguments: vec![], statements: bar_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); let bar_checked = TypedFunction { arguments: vec![], statements: bar_statements_checked, - signature: Signature { - inputs: vec![], - outputs: vec![Type::FieldElement], - }, + signature: DeclarationSignature::new() + .inputs(vec![]) + .outputs(vec![DeclarationType::FieldElement]), }; let types = HashMap::new(); - let module_id = "".into(); - let mut checker = new_with_args(HashSet::new(), 0, functions); + let mut checker: Checker = new_with_args(HashSet::new(), 0, functions); assert_eq!( - checker.check_function(bar, &module_id, &types), + checker.check_function(bar, &MODULE_ID.into(), &types), Ok(bar_checked) ); } + #[test] + fn duplicate_argument_name() { + // def main(field a, bool a): + // return + + // should fail + + let mut f = function0(); + f.value.arguments = vec![ + absy::Parameter::private( + absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), + ) + .mock(), + absy::Parameter::private( + absy::Variable::new("a", UnresolvedType::Boolean.mock()).mock(), + ) + .mock(), + ]; + f.value.signature = UnresolvedSignature::new().inputs(vec![ + UnresolvedType::FieldElement.mock(), + UnresolvedType::Boolean.mock(), + ]); + + let mut checker: Checker = new_with_args(HashSet::new(), 0, HashSet::new()); + assert_eq!( + checker + .check_function(f, &"".into(), &HashMap::new()) + .unwrap_err()[0] + .message, + "Duplicate name in function definition: `a` was previously declared as an argument or a generic constant" + ); + } + #[test] fn duplicate_main_function() { // def main(a): @@ -3839,7 +4720,7 @@ mod tests { // should fail let main1_statements: Vec = vec![Statement::Return( ExpressionList { - expressions: vec![Expression::FieldConstant(BigUint::from(1u32)).mock()], + expressions: vec![Expression::IntConstant(1usize.into()).mock()], } .mock(), ) @@ -3853,7 +4734,7 @@ mod tests { let main2_statements: Vec = vec![Statement::Return( ExpressionList { - expressions: vec![Expression::FieldConstant(BigUint::from(1u32)).mock()], + expressions: vec![Expression::IntConstant(1usize.into()).mock()], } .mock(), ) @@ -3864,20 +4745,18 @@ mod tests { let main1 = Function { arguments: main1_arguments, statements: main1_statements, - signature: UnresolvedSignature { - inputs: vec![UnresolvedType::FieldElement.mock()], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![UnresolvedType::FieldElement.mock()]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); let main2 = Function { arguments: main2_arguments, statements: main2_statements, - signature: UnresolvedSignature { - inputs: vec![], - outputs: vec![UnresolvedType::FieldElement.mock()], - }, + signature: UnresolvedSignature::new() + .inputs(vec![]) + .outputs(vec![UnresolvedType::FieldElement.mock()]), } .mock(); @@ -3904,9 +4783,9 @@ mod tests { main: "main".into(), }; - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); assert_eq!( - checker.check_program::(program), + checker.check_program(program), Err(vec![Error { inner: ErrorInner { pos: None, @@ -3926,7 +4805,7 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), @@ -3963,7 +4842,7 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); let _: Result, Vec> = checker.check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), @@ -3992,11 +4871,12 @@ mod tests { mod structs { use super::*; + use crate::typed_absy::types::StructMember; /// solver function to create a module at location "" with a single symbol `Foo { foo: field }` fn create_module_with_foo( s: StructDefinition<'static>, - ) -> (Checker<'static>, State<'static, Bn128Field>) { + ) -> (Checker, State) { let module_id: PathBuf = "".into(); let module: Module = Module { @@ -4011,7 +4891,7 @@ mod tests { let mut state = State::::new(vec![(module_id.clone(), module)].into_iter().collect()); - let mut checker = Checker::new(); + let mut checker: Checker = Checker::new(); checker.check_module(&module_id, &mut state).unwrap(); @@ -4027,12 +4907,16 @@ mod tests { // an empty struct should be allowed to be defined let module_id = "".into(); let types = HashMap::new(); - let declaration = StructDefinition { fields: vec![] }.mock(); + let declaration: StructDefinitionNode = StructDefinition { fields: vec![] }.mock(); - let expected_type = Type::Struct(StructType::new("".into(), "Foo".into(), vec![])); + let expected_type = DeclarationType::Struct(DeclarationStructType::new( + "".into(), + "Foo".into(), + vec![], + )); assert_eq!( - Checker::new().check_struct_type_declaration( + Checker::::new().check_struct_type_declaration( "Foo".into(), declaration, &module_id, @@ -4047,7 +4931,7 @@ mod tests { // a valid struct should be allowed to be defined let module_id = "".into(); let types = HashMap::new(); - let declaration = StructDefinition { + let declaration: StructDefinitionNode = StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -4063,17 +4947,17 @@ mod tests { } .mock(); - let expected_type = Type::Struct(StructType::new( + let expected_type = DeclarationType::Struct(DeclarationStructType::new( "".into(), "Foo".into(), vec![ - StructMember::new("foo".into(), Type::FieldElement), - StructMember::new("bar".into(), Type::Boolean), + DeclarationStructMember::new("foo".into(), DeclarationType::FieldElement), + DeclarationStructMember::new("bar".into(), DeclarationType::Boolean), ], )); assert_eq!( - Checker::new().check_struct_type_declaration( + Checker::::new().check_struct_type_declaration( "Foo".into(), declaration, &module_id, @@ -4083,67 +4967,13 @@ mod tests { ); } - #[test] - fn preserve_order() { - // two structs with inverted members are not equal - let module_id = "".into(); - let types = HashMap::new(); - - let declaration0 = StructDefinition { - fields: vec![ - StructDefinitionField { - id: "foo", - ty: UnresolvedType::FieldElement.mock(), - } - .mock(), - StructDefinitionField { - id: "bar", - ty: UnresolvedType::Boolean.mock(), - } - .mock(), - ], - } - .mock(); - - let declaration1 = StructDefinition { - fields: vec![ - StructDefinitionField { - id: "bar", - ty: UnresolvedType::Boolean.mock(), - } - .mock(), - StructDefinitionField { - id: "foo", - ty: UnresolvedType::FieldElement.mock(), - } - .mock(), - ], - } - .mock(); - - assert_ne!( - Checker::new().check_struct_type_declaration( - "Foo".into(), - declaration0, - &module_id, - &types - ), - Checker::new().check_struct_type_declaration( - "Foo".into(), - declaration1, - &module_id, - &types - ) - ); - } - #[test] fn duplicate_member_def() { // definition of a struct with a duplicate member should be rejected let module_id = "".into(); let types = HashMap::new(); - let declaration = StructDefinition { + let declaration: StructDefinitionNode = StructDefinition { fields: vec![ StructDefinitionField { id: "foo", @@ -4160,7 +4990,7 @@ mod tests { .mock(); assert_eq!( - Checker::new() + Checker::::new() .check_struct_type_declaration( "Foo".into(), declaration, @@ -4228,15 +5058,18 @@ mod tests { .unwrap() .get(&"Bar".to_string()) .unwrap(), - &Type::Struct(StructType::new( + &DeclarationType::Struct(DeclarationStructType::new( module_id.clone(), "Bar".into(), - vec![StructMember::new( + vec![DeclarationStructMember::new( "foo".into(), - Type::Struct(StructType::new( + DeclarationType::Struct(DeclarationStructType::new( module_id, "Foo".into(), - vec![StructMember::new("foo".into(), Type::FieldElement)] + vec![DeclarationStructMember::new( + "foo".into(), + DeclarationType::FieldElement + )] )) )] )) @@ -4373,7 +5206,7 @@ mod tests { // an undefined type cannot be checked // Bar - let (checker, state) = create_module_with_foo(StructDefinition { + let (mut checker, state) = create_module_with_foo(StructDefinition { fields: vec![StructDefinitionField { id: "foo", ty: UnresolvedType::FieldElement.mock(), @@ -4384,7 +5217,7 @@ mod tests { assert_eq!( checker.check_type( UnresolvedType::User("Foo".into()).mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ), Ok(Type::Struct(StructType::new( @@ -4398,7 +5231,7 @@ mod tests { checker .check_type( UnresolvedType::User("Bar".into()).mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ) .unwrap_err() @@ -4406,121 +5239,6 @@ mod tests { "Undefined type Bar" ); } - - #[test] - fn parameter() { - // a defined type can be used as parameter - - // an undefined type cannot be used as parameter - - let (checker, state) = create_module_with_foo(StructDefinition { - fields: vec![StructDefinitionField { - id: "foo", - ty: UnresolvedType::FieldElement.mock(), - } - .mock()], - }); - - assert_eq!( - checker.check_parameter( - absy::Parameter { - id: - absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),) - .mock(), - private: true, - } - .mock(), - &PathBuf::from(MODULE_ID).into(), - &state.types, - ), - Ok(Parameter { - id: Variable::with_id_and_type( - "a", - Type::Struct(StructType::new( - "".into(), - "Foo".into(), - vec![StructMember::new("foo".into(), Type::FieldElement)] - )) - ), - private: true - }) - ); - - assert_eq!( - checker - .check_parameter( - absy::Parameter { - id: absy::Variable::new( - "a", - UnresolvedType::User("Bar".into()).mock(), - ) - .mock(), - private: true, - } - .mock(), - &PathBuf::from(MODULE_ID).into(), - &state.types, - ) - .unwrap_err()[0] - .message, - "Undefined type Bar" - ); - } - - #[test] - fn variable_declaration() { - // a defined type can be used in a variable declaration - - // an undefined type cannot be used in a variable declaration - - let (mut checker, state) = create_module_with_foo(StructDefinition { - fields: vec![StructDefinitionField { - id: "foo", - ty: UnresolvedType::FieldElement.mock(), - } - .mock()], - }); - - assert_eq!( - checker.check_statement::( - Statement::Declaration( - absy::Variable::new("a", UnresolvedType::User("Foo".into()).mock(),) - .mock() - ) - .mock(), - &PathBuf::from(MODULE_ID).into(), - &state.types, - ), - Ok(TypedStatement::Declaration(Variable::with_id_and_type( - "a", - Type::Struct(StructType::new( - "".into(), - "Foo".into(), - vec![StructMember::new("foo".into(), Type::FieldElement)] - )) - ))) - ); - - assert_eq!( - checker - .check_parameter( - absy::Parameter { - id: absy::Variable::new( - "a", - UnresolvedType::User("Bar".into()).mock(), - ) - .mock(), - private: true, - } - .mock(), - &PathBuf::from(MODULE_ID).into(), - &state.types, - ) - .unwrap_err()[0] - .message, - "Undefined type Bar" - ); - } } /// tests about accessing members @@ -4543,20 +5261,17 @@ mod tests { }); assert_eq!( - checker.check_expression::( + checker.check_expression( Expression::Member( box Expression::InlineStruct( "Foo".into(), - vec![( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - )] + vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), "foo".into() ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ), Ok(FieldElementExpression::Member( @@ -4592,20 +5307,17 @@ mod tests { assert_eq!( checker - .check_expression::( + .check_expression( Expression::Member( box Expression::InlineStruct( "Foo".into(), - vec![( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - )] + vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), "bar".into() ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ) .unwrap_err() @@ -4633,16 +5345,13 @@ mod tests { assert_eq!( checker - .check_expression::( + .check_expression( Expression::InlineStruct( "Bar".into(), - vec![( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - )] + vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ) .unwrap_err() @@ -4674,19 +5383,16 @@ mod tests { }); assert_eq!( - checker.check_expression::( + checker.check_expression( Expression::InlineStruct( "Foo".into(), vec![ - ( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - ), + ("foo", Expression::IntConstant(42usize.into()).mock()), ("bar", Expression::BooleanConstant(true).mock()) ] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ), Ok(StructExpressionInner::Value(vec![ @@ -4728,19 +5434,16 @@ mod tests { }); assert_eq!( - checker.check_expression::( + checker.check_expression( Expression::InlineStruct( "Foo".into(), vec![ ("bar", Expression::BooleanConstant(true).mock()), - ( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - ) + ("foo", Expression::IntConstant(42usize.into()).mock()) ] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ), Ok(StructExpressionInner::Value(vec![ @@ -4783,16 +5486,13 @@ mod tests { assert_eq!( checker - .check_expression::( + .check_expression( Expression::InlineStruct( "Foo".into(), - vec![( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - )] + vec![("foo", Expression::IntConstant(42usize.into()).mock())] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ) .unwrap_err() @@ -4827,7 +5527,7 @@ mod tests { assert_eq!( checker - .check_expression::( + .check_expression( Expression::InlineStruct( "Foo".into(), vec![( @@ -4835,11 +5535,11 @@ mod tests { Expression::BooleanConstant(true).mock() ),( "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() + Expression::IntConstant(42usize.into()).mock() )] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ).unwrap_err() .message, @@ -4848,35 +5548,137 @@ mod tests { assert_eq!( checker - .check_expression::( + .check_expression( Expression::InlineStruct( "Foo".into(), vec![ - ( - "bar", - Expression::FieldConstant(BigUint::from(42u32)).mock() - ), - ( - "foo", - Expression::FieldConstant(BigUint::from(42u32)).mock() - ) + ("bar", Expression::IntConstant(42usize.into()).mock()), + ("foo", Expression::IntConstant(42usize.into()).mock()) ] ) .mock(), - &PathBuf::from(MODULE_ID).into(), + &PathBuf::from(MODULE_ID), &state.types ) .unwrap_err() .message, - "Member bar of struct Foo has type bool, found 42 of type field" + "Member bar of struct Foo has type bool, found 42 of type {integer}" ); } } } + mod int_inference { + use super::*; + + #[test] + fn two_candidates() { + // def foo(field a) -> field: + // return 0 + + // def foo(u32 a) -> field: + // return 0 + + // def main() -> field: + // return foo(0) + + // should fail + + let mut foo_field = function0(); + + foo_field.value.arguments = vec![absy::Parameter::private( + absy::Variable { + id: "a", + _type: UnresolvedType::FieldElement.mock(), + } + .mock(), + ) + .mock()]; + foo_field.value.statements = vec![Statement::Return( + ExpressionList { + expressions: vec![Expression::IntConstant(0usize.into()).mock()], + } + .mock(), + ) + .mock()]; + foo_field.value.signature = UnresolvedSignature::new() + .inputs(vec![UnresolvedType::FieldElement.mock()]) + .outputs(vec![UnresolvedType::FieldElement.mock()]); + + let mut foo_u32 = function0(); + + foo_u32.value.arguments = vec![absy::Parameter::private( + absy::Variable { + id: "a", + _type: UnresolvedType::Uint(32).mock(), + } + .mock(), + ) + .mock()]; + foo_u32.value.statements = vec![Statement::Return( + ExpressionList { + expressions: vec![Expression::IntConstant(0usize.into()).mock()], + } + .mock(), + ) + .mock()]; + foo_u32.value.signature = UnresolvedSignature::new() + .inputs(vec![UnresolvedType::Uint(32).mock()]) + .outputs(vec![UnresolvedType::FieldElement.mock()]); + + let mut main = function0(); + + main.value.statements = vec![Statement::Return( + ExpressionList { + expressions: vec![Expression::FunctionCall( + "foo", + None, + vec![Expression::IntConstant(0usize.into()).mock()], + ) + .mock()], + } + .mock(), + ) + .mock()]; + main.value.signature = + UnresolvedSignature::new().outputs(vec![UnresolvedType::FieldElement.mock()]); + + let m = Module::with_symbols(vec![ + absy::SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(foo_field), + } + .mock(), + absy::SymbolDeclaration { + id: "foo", + symbol: Symbol::HereFunction(foo_u32), + } + .mock(), + absy::SymbolDeclaration { + id: "main", + symbol: Symbol::HereFunction(main), + } + .mock(), + ]); + + let p = Program { + main: "".into(), + modules: vec![("".into(), m)].into_iter().collect(), + }; + + let errors = Checker::::new().check_program(p).unwrap_err(); + + assert_eq!(errors.len(), 1); + + assert_eq!( + errors[0].inner.message, + "Ambiguous call to function foo, 2 candidates were found. Please be more explicit." + ); + } + } + mod assignee { use super::*; - use num_bigint::BigUint; #[test] fn identifier() { @@ -4885,9 +5687,11 @@ mod tests { let types = HashMap::new(); let module_id = "".into(); - let mut checker: Checker = Checker::new(); + + let mut checker: Checker = Checker::new(); + checker - .check_statement::( + .check_statement( Statement::Declaration( absy::Variable::new("a", UnresolvedType::FieldElement.mock()).mock(), ) @@ -4898,7 +5702,7 @@ mod tests { .unwrap(); assert_eq!( - checker.check_assignee::(a, &module_id, &types), + checker.check_assignee(a, &module_id, &types), Ok(TypedAssignee::Identifier( typed_absy::Variable::field_element("a") )) @@ -4911,70 +5715,22 @@ mod tests { // a[2] = 42 let a = Assignee::Select( box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( - Expression::FieldConstant(BigUint::from(2u32)).mock(), - ), + box RangeOrExpression::Expression(Expression::IntConstant(2usize.into()).mock()), ) .mock(); let types = HashMap::new(); let module_id = "".into(); - let mut checker: Checker = Checker::new(); + let mut checker: Checker = Checker::new(); checker - .check_statement::( - Statement::Declaration( - absy::Variable::new( - "a", - UnresolvedType::array(UnresolvedType::FieldElement.mock(), 33).mock(), - ) - .mock(), - ) - .mock(), - &module_id, - &types, - ) - .unwrap(); - - assert_eq!( - checker.check_assignee::(a, &module_id, &types), - Ok(TypedAssignee::Select( - box TypedAssignee::Identifier(typed_absy::Variable::field_array("a", 33)), - box FieldElementExpression::Number(Bn128Field::from(2u32)).into() - )) - ); - } - - #[test] - fn array_of_array_element() { - // field[33][42] a - // a[1][2] - let a = Assignee::Select( - box Assignee::Select( - box Assignee::Identifier("a").mock(), - box RangeOrExpression::Expression( - Expression::FieldConstant(BigUint::from(1u32)).mock(), - ), - ) - .mock(), - box RangeOrExpression::Expression( - Expression::FieldConstant(BigUint::from(2u32)).mock(), - ), - ) - .mock(); - - let types = HashMap::new(); - let module_id = "".into(); - let mut checker: Checker = Checker::new(); - checker - .check_statement::( + .check_statement( Statement::Declaration( absy::Variable::new( "a", UnresolvedType::array( - UnresolvedType::array(UnresolvedType::FieldElement.mock(), 33) - .mock(), - 42, + UnresolvedType::FieldElement.mock(), + Expression::IntConstant(33usize.into()).mock(), ) .mock(), ) @@ -4987,17 +5743,72 @@ mod tests { .unwrap(); assert_eq!( - checker.check_assignee::(a, &module_id, &types), + checker.check_assignee(a, &module_id, &types), + Ok(TypedAssignee::Select( + box TypedAssignee::Identifier(typed_absy::Variable::field_array( + "a", + 33u32.into() + )), + box 2u32.into() + )) + ); + } + + #[test] + fn array_of_array_element() { + // field[33][42] a + // a[1][2] + let a: AssigneeNode = Assignee::Select( + box Assignee::Select( + box Assignee::Identifier("a").mock(), + box RangeOrExpression::Expression( + Expression::IntConstant(1usize.into()).mock(), + ), + ) + .mock(), + box RangeOrExpression::Expression(Expression::IntConstant(2usize.into()).mock()), + ) + .mock(); + + let types = HashMap::new(); + let module_id = "".into(); + + let mut checker: Checker = Checker::new(); + checker + .check_statement( + Statement::Declaration( + absy::Variable::new( + "a", + UnresolvedType::array( + UnresolvedType::array( + UnresolvedType::FieldElement.mock(), + Expression::IntConstant(33usize.into()).mock(), + ) + .mock(), + Expression::IntConstant(42usize.into()).mock(), + ) + .mock(), + ) + .mock(), + ) + .mock(), + &module_id, + &types, + ) + .unwrap(); + + assert_eq!( + checker.check_assignee(a, &module_id, &types), Ok(TypedAssignee::Select( box TypedAssignee::Select( box TypedAssignee::Identifier(typed_absy::Variable::array( "a", - Type::array(Type::FieldElement, 33), - 42 + Type::array((Type::FieldElement, 33u32)), + 42u32, )), - box FieldElementExpression::Number(Bn128Field::from(1u32)).into() + box 1u32.into() ), - box FieldElementExpression::Number(Bn128Field::from(2u32)).into() + box 2u32.into() )) ); } diff --git a/zokrates_core/src/solvers/mod.rs b/zokrates_core/src/solvers/mod.rs index 7f5b9b2b..d6bc0c0c 100644 --- a/zokrates_core/src/solvers/mod.rs +++ b/zokrates_core/src/solvers/mod.rs @@ -1,10 +1,5 @@ -#[cfg(feature = "bellman")] -use pairing_ce::bn256::Bn256; use serde::{Deserialize, Serialize}; use std::fmt; -#[cfg(feature = "bellman")] -use zokrates_embed::generate_sha256_round_witness; -use zokrates_field::{Bn128Field, Field}; #[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)] pub enum Solver { @@ -48,168 +43,3 @@ impl Solver { Solver::Bits(width) } } - -pub trait Executable { - fn execute(&self, inputs: &Vec) -> Result, String>; -} - -impl Executable for Solver { - fn execute(&self, inputs: &Vec) -> Result, String> { - let (expected_input_count, expected_output_count) = self.get_signature(); - assert_eq!(inputs.len(), expected_input_count); - - let res = match self { - Solver::ConditionEq => match inputs[0].is_zero() { - true => vec![T::zero(), T::one()], - false => vec![ - T::one(), - T::one().checked_div(&inputs[0]).unwrap_or(T::one()), - ], - }, - Solver::Bits(bit_width) => { - let mut num = inputs[0].clone(); - let mut res = vec![]; - - for i in (0..*bit_width).rev() { - if T::from(2).pow(i) <= num { - num = num - T::from(2).pow(i); - res.push(T::one()); - } else { - res.push(T::zero()); - } - } - res - } - Solver::Xor => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); - - vec![x.clone() + y.clone() - T::from(2) * x * y] - } - Solver::Or => { - let x = inputs[0].clone(); - let y = inputs[1].clone(); - - vec![x.clone() + y.clone() - x * y] - } - // res = b * c - (2b * c - b - c) * (a) - Solver::ShaAndXorAndXorAnd => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![b.clone() * c.clone() - (T::from(2) * b.clone() * c.clone() - b - c) * a] - } - // res = a(b - c) + c - Solver::ShaCh => { - let a = inputs[0].clone(); - let b = inputs[1].clone(); - let c = inputs[2].clone(); - vec![a * (b - c.clone()) + c] - } - Solver::Div => vec![inputs[0] - .clone() - .checked_div(&inputs[1]) - .unwrap_or(T::one())], - Solver::EuclideanDiv => { - use num::CheckedDiv; - - let n = inputs[0].clone().to_biguint(); - let d = inputs[1].clone().to_biguint(); - - let q = n.checked_div(&d).unwrap_or(0u32.into()); - let r = n - d * &q; - vec![T::try_from(q).unwrap(), T::try_from(r).unwrap()] - } - #[cfg(feature = "bellman")] - Solver::Sha256Round => { - assert_eq!(T::id(), Bn128Field::id()); - let i = &inputs[0..512]; - let h = &inputs[512..]; - let to_fr = |x: &T| { - use pairing_ce::ff::{PrimeField, ScalarEngine}; - let s = x.to_dec_string(); - ::Fr::from_str(&s).unwrap() - }; - let i: Vec<_> = i.iter().map(|x| to_fr(x)).collect(); - let h: Vec<_> = h.iter().map(|x| to_fr(x)).collect(); - assert_eq!(h.len(), 256); - generate_sha256_round_witness::(&i, &h) - .into_iter() - .map(|x| { - use bellman_ce::pairing::ff::{PrimeField, PrimeFieldRepr}; - let mut res: Vec = vec![]; - x.into_repr().write_le(&mut res).unwrap(); - T::from_byte_vector(res) - }) - .collect() - } - }; - - assert_eq!(res.len(), expected_output_count); - - Ok(res) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use zokrates_field::Bn128Field; - - mod eq_condition { - - // Wanted: (Y = (X != 0) ? 1 : 0) - // # Y = if X == 0 then 0 else 1 fi - // # M = if X == 0 then 1 else 1/X fi - - use super::*; - - #[test] - fn execute() { - let cond_eq = Solver::ConditionEq; - let inputs = vec![0]; - let r = cond_eq - .execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect()) - .unwrap(); - let res: Vec = vec![0, 1].iter().map(|&i| Bn128Field::from(i)).collect(); - assert_eq!(r, &res[..]); - } - - #[test] - fn execute_non_eq() { - let cond_eq = Solver::ConditionEq; - let inputs = vec![1]; - let r = cond_eq - .execute(&inputs.iter().map(|&i| Bn128Field::from(i)).collect()) - .unwrap(); - let res: Vec = vec![1, 1].iter().map(|&i| Bn128Field::from(i)).collect(); - assert_eq!(r, &res[..]); - } - } - - #[test] - fn bits_of_one() { - let bits = Solver::Bits(Bn128Field::get_required_bits()); - let inputs = vec![Bn128Field::from(1)]; - let res = bits.execute(&inputs).unwrap(); - assert_eq!(res[253], Bn128Field::from(1)); - for i in 0..253 { - assert_eq!(res[i], Bn128Field::from(0)); - } - } - - #[test] - fn bits_of_42() { - let bits = Solver::Bits(Bn128Field::get_required_bits()); - let inputs = vec![Bn128Field::from(42)]; - let res = bits.execute(&inputs).unwrap(); - - assert_eq!(res[253], Bn128Field::from(0)); - assert_eq!(res[252], Bn128Field::from(1)); - assert_eq!(res[251], Bn128Field::from(0)); - assert_eq!(res[250], Bn128Field::from(1)); - assert_eq!(res[249], Bn128Field::from(0)); - assert_eq!(res[248], Bn128Field::from(1)); - assert_eq!(res[247], Bn128Field::from(0)); - } -} diff --git a/zokrates_core/src/static_analysis/bounds_checker.rs b/zokrates_core/src/static_analysis/bounds_checker.rs new file mode 100644 index 00000000..de43d5f0 --- /dev/null +++ b/zokrates_core/src/static_analysis/bounds_checker.rs @@ -0,0 +1,134 @@ +use crate::typed_absy::result_folder::*; +use crate::typed_absy::*; +use zokrates_field::Field; + +pub struct BoundsChecker; + +pub type Error = String; + +impl BoundsChecker { + pub fn check(p: TypedProgram) -> Result, Error> { + BoundsChecker.fold_program(p) + } + + pub fn check_select<'ast, T: Field, U: Select<'ast, T>>( + &mut self, + array: ArrayExpression<'ast, T>, + index: UExpression<'ast, T>, + ) -> Result { + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; + + match (array.get_array_type().size.as_inner(), index.as_inner()) { + (UExpressionInner::Value(size), UExpressionInner::Value(index)) => { + if index >= size { + return Err(format!( + "Out of bounds access: {}[{}] but {} is of size {}", + array, index, array, size + )); + } + } + _ => unreachable!(), + }; + + Ok(U::select(array, index)) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker { + type Error = Error; + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + ArrayExpressionInner::Select(box array, box index) => self + .check_select::<_, ArrayExpression<_>>(array, index) + .map(|a| a.into_inner()), + ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = self.fold_array_expression(array)?; + let from = self.fold_uint_expression(from)?; + let to = self.fold_uint_expression(to)?; + + match ( + array.get_array_type().size.as_inner(), + from.as_inner(), + to.as_inner(), + ) { + ( + UExpressionInner::Value(size), + UExpressionInner::Value(from), + UExpressionInner::Value(to), + ) => { + if from > to { + return Err(format!( + "Slice is created from an invalid range {}..{}", + from, to + )); + } + + if from > size { + return Err(format!("Lower bound {} of slice {}[{}..{}] is out of bounds for array of size {}", from, array, from, to, size)); + } + + if to > size { + return Err(format!("Upper bound {} of slice {}[{}..{}] is out of bounds for array of size {}", to, array, from, to, size)); + } + } + _ => unreachable!(), + }; + + Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + } + e => fold_array_expression_inner(self, ty, e), + } + } + + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + StructExpressionInner::Select(box array, box index) => self + .check_select::<_, StructExpression<_>>(array, index) + .map(|a| a.into_inner()), + e => fold_struct_expression_inner(self, ty, e), + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + FieldElementExpression::Select(box array, box index) => self.check_select(array, index), + e => fold_field_expression(self, e), + } + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + BooleanExpression::Select(box array, box index) => self.check_select(array, index), + e => fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + UExpressionInner::Select(box array, box index) => self + .check_select::<_, UExpression<_>>(array, index) + .map(|a| a.into_inner()), + e => fold_uint_expression_inner(self, bitwidth, e), + } + } +} diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 9f23767a..d46a3f2c 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -1,31 +1,34 @@ use crate::typed_absy; -use crate::typed_absy::types::{StructType, UBitwidth}; +use crate::typed_absy::types::UBitwidth; use crate::zir; use std::marker::PhantomData; use zokrates_field::Field; +use std::convert::{TryFrom, TryInto}; + pub struct Flattener { phantom: PhantomData, } -fn flatten_identifier_rec<'a>( - id: zir::SourceIdentifier<'a>, - ty: &typed_absy::Type, -) -> Vec> { +fn flatten_identifier_rec<'ast>( + id: zir::SourceIdentifier<'ast>, + ty: &typed_absy::types::ConcreteType, +) -> Vec> { match ty { - typed_absy::Type::FieldElement => vec![zir::Variable { + typed_absy::ConcreteType::Int => unreachable!(), + typed_absy::ConcreteType::FieldElement => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::FieldElement, }], - typed_absy::Type::Boolean => vec![zir::Variable { + typed_absy::types::ConcreteType::Boolean => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::Boolean, }], - typed_absy::Type::Uint(bitwidth) => vec![zir::Variable { + typed_absy::types::ConcreteType::Uint(bitwidth) => vec![zir::Variable { id: zir::Identifier::Source(id), _type: zir::Type::uint(bitwidth.to_usize()), }], - typed_absy::Type::Array(array_type) => (0..array_type.size) + typed_absy::types::ConcreteType::Array(array_type) => (0..array_type.size) .flat_map(|i| { flatten_identifier_rec( zir::SourceIdentifier::Select(box id.clone(), i), @@ -33,7 +36,7 @@ fn flatten_identifier_rec<'a>( ) }) .collect(), - typed_absy::Type::Struct(members) => members + typed_absy::types::ConcreteType::Struct(members) => members .iter() .flat_map(|struct_member| { flatten_identifier_rec( @@ -57,17 +60,6 @@ impl<'ast, T: Field> Flattener { fold_program(self, p) } - fn fold_module(&mut self, p: typed_absy::TypedModule<'ast, T>) -> zir::ZirModule<'ast, T> { - fold_module(self, p) - } - - fn fold_function_symbol( - &mut self, - s: typed_absy::TypedFunctionSymbol<'ast, T>, - ) -> zir::ZirFunctionSymbol<'ast, T> { - fold_function_symbol(self, s) - } - fn fold_function( &mut self, f: typed_absy::TypedFunction<'ast, T>, @@ -75,9 +67,12 @@ impl<'ast, T: Field> Flattener { fold_function(self, f) } - fn fold_parameter(&mut self, p: typed_absy::Parameter<'ast>) -> Vec> { + fn fold_declaration_parameter( + &mut self, + p: typed_absy::DeclarationParameter<'ast>, + ) -> Vec> { let private = p.private; - self.fold_variable(p.id) + self.fold_variable(p.id.try_into().unwrap()) .into_iter() .map(|v| zir::Parameter { id: v, private }) .collect() @@ -87,10 +82,12 @@ impl<'ast, T: Field> Flattener { zir::SourceIdentifier::Basic(n) } - fn fold_variable(&mut self, v: typed_absy::Variable<'ast>) -> Vec> { + fn fold_variable(&mut self, v: typed_absy::Variable<'ast, T>) -> Vec> { let id = self.fold_name(v.id.clone()); let ty = v.get_type(); + let ty = typed_absy::types::ConcreteType::try_from(ty).unwrap(); + flatten_identifier_rec(id, &ty) } @@ -102,36 +99,34 @@ impl<'ast, T: Field> Flattener { typed_absy::TypedAssignee::Identifier(v) => self.fold_variable(v), typed_absy::TypedAssignee::Select(box a, box i) => { use typed_absy::Typed; - let count = match a.get_type() { - typed_absy::Type::Array(array_ty) => array_ty.ty.get_primitive_count(), + let count = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap() { + typed_absy::ConcreteType::Array(array_ty) => array_ty.ty.get_primitive_count(), _ => unreachable!(), }; let a = self.fold_assignee(a); - match i { - typed_absy::FieldElementExpression::Number(n) => { - let index = n.to_dec_string().parse::().unwrap(); - a[index * count..(index + 1) * count].to_vec() + match i.as_inner() { + typed_absy::UExpressionInner::Value(index) => { + a[*index as usize * count..(*index as usize + 1) * count].to_vec() } - i => unreachable!("index {} not allowed, should be a constant", i), + i => unreachable!("index {:?} not allowed, should be a constant", i), } } typed_absy::TypedAssignee::Member(box a, m) => { use typed_absy::Typed; - let (offset, size) = match a.get_type() { - typed_absy::Type::Struct(struct_type) => { - struct_type - .members - .iter() - .fold((0, None), |(offset, size), member| match size { - Some(_) => (offset, size), - None => match m == member.id { - true => (offset, Some(member.ty.get_primitive_count())), - false => (offset + member.ty.get_primitive_count(), None), - }, - }) - } + let (offset, size) = match typed_absy::ConcreteType::try_from(a.get_type()).unwrap() + { + typed_absy::ConcreteType::Struct(struct_type) => struct_type + .members + .iter() + .fold((0, None), |(offset, size), member| match size { + Some(_) => (offset, size), + None => match m == member.id { + true => (offset, Some(member.ty.get_primitive_count())), + false => (offset + member.ty.get_primitive_count(), None), + }, + }), _ => unreachable!(), }; @@ -151,6 +146,16 @@ impl<'ast, T: Field> Flattener { fold_statement(self, s) } + fn fold_expression_or_spread( + &mut self, + e: typed_absy::TypedExpressionOrSpread<'ast, T>, + ) -> Vec> { + match e { + typed_absy::TypedExpressionOrSpread::Expression(e) => self.fold_expression(e), + typed_absy::TypedExpressionOrSpread::Spread(s) => self.fold_array_expression(s.array), + } + } + fn fold_expression( &mut self, e: typed_absy::TypedExpression<'ast, T>, @@ -161,8 +166,9 @@ impl<'ast, T: Field> Flattener { } typed_absy::TypedExpression::Boolean(e) => vec![self.fold_boolean_expression(e).into()], typed_absy::TypedExpression::Uint(e) => vec![self.fold_uint_expression(e).into()], - typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e).into(), - typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e).into(), + typed_absy::TypedExpression::Array(e) => self.fold_array_expression(e), + typed_absy::TypedExpression::Struct(e) => self.fold_struct_expression(e), + typed_absy::TypedExpression::Int(_) => unreachable!(), } } @@ -185,26 +191,20 @@ impl<'ast, T: Field> Flattener { es: typed_absy::TypedExpressionList<'ast, T>, ) -> zir::ZirExpressionList<'ast, T> { match es { - typed_absy::TypedExpressionList::FunctionCall(id, arguments, _) => { - zir::ZirExpressionList::FunctionCall( - self.fold_function_key(id), + typed_absy::TypedExpressionList::EmbedCall(embed, generics, arguments, _) => { + zir::ZirExpressionList::EmbedCall( + embed, + generics, arguments .into_iter() .flat_map(|a| self.fold_expression(a)) .collect(), - vec![], ) } + _ => unreachable!("should have been inlined"), } } - fn fold_function_key( - &mut self, - k: typed_absy::types::FunctionKey<'ast>, - ) -> zir::types::FunctionKey<'ast> { - k.into() - } - fn fold_field_expression( &mut self, e: typed_absy::FieldElementExpression<'ast, T>, @@ -234,7 +234,7 @@ impl<'ast, T: Field> Flattener { fn fold_array_expression_inner( &mut self, - ty: &typed_absy::Type, + ty: &typed_absy::types::ConcreteType, size: usize, e: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { @@ -242,26 +242,13 @@ impl<'ast, T: Field> Flattener { } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &typed_absy::types::ConcreteStructType, e: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { fold_struct_expression_inner(self, ty, e) } } -pub fn fold_module<'ast, T: Field>( - f: &mut Flattener, - p: typed_absy::TypedModule<'ast, T>, -) -> zir::ZirModule<'ast, T> { - zir::ZirModule { - functions: p - .functions - .into_iter() - .map(|(key, fun)| (f.fold_function_key(key), f.fold_function_symbol(fun))) - .collect(), - } -} - pub fn fold_statement<'ast, T: Field>( f: &mut Flattener, s: typed_absy::TypedStatement<'ast, T>, @@ -284,9 +271,7 @@ pub fn fold_statement<'ast, T: Field>( } typed_absy::TypedStatement::Declaration(v) => { let v = f.fold_variable(v); - v.into_iter() - .map(|v| zir::ZirStatement::Declaration(v)) - .collect() + v.into_iter().map(zir::ZirStatement::Declaration).collect() } typed_absy::TypedStatement::Assertion(e) => { let e = f.fold_boolean_expression(e); @@ -302,19 +287,23 @@ pub fn fold_statement<'ast, T: Field>( f.fold_expression_list(elist), )] } + typed_absy::TypedStatement::PushCallLog(..) => vec![], + typed_absy::TypedStatement::PopCallLog => vec![], } } pub fn fold_array_expression_inner<'ast, T: Field>( f: &mut Flattener, - t: &typed_absy::Type, + ty: &typed_absy::types::ConcreteType, size: usize, - e: typed_absy::ArrayExpressionInner<'ast, T>, + array: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { - match e { + match array { typed_absy::ArrayExpressionInner::Identifier(id) => { - let variables = - flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::array(t.clone(), size)); + let variables = flatten_identifier_rec( + f.fold_name(id), + &typed_absy::types::ConcreteType::array((ty.clone(), size)), + ); variables .into_iter() .map(|v| match v._type { @@ -326,10 +315,16 @@ pub fn fold_array_expression_inner<'ast, T: Field>( }) .collect() } - typed_absy::ArrayExpressionInner::Value(exprs) => exprs - .into_iter() - .flat_map(|e| f.fold_expression(e)) - .collect(), + typed_absy::ArrayExpressionInner::Value(exprs) => { + let exprs: Vec<_> = exprs + .into_iter() + .flat_map(|e| f.fold_expression_or_spread(e)) + .collect(); + + assert_eq!(exprs.len(), size * ty.get_primitive_count()); + + exprs + } typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(), typed_absy::ArrayExpressionInner::IfElse( box condition, @@ -369,40 +364,74 @@ pub fn fold_array_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty.clone()) + .unwrap() + .get_primitive_count() + }) .sum(); // we also need the size of this member - let size = t.get_primitive_count() * size; + let size = ty.get_primitive_count() * size; s[offset..offset + size].to_vec() } typed_absy::ArrayExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - match index { - zir::FieldElementExpression::Number(i) => { - let size = t.get_primitive_count() * size; - let start = i.to_dec_string().parse::().unwrap() * size; + match index.into_inner() { + zir::UExpressionInner::Value(i) => { + let size = ty.clone().get_primitive_count() * size; + let start = i as usize * size; let end = start + size; array[start..end].to_vec() } _ => unreachable!(), } } + typed_absy::ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = f.fold_array_expression(array); + let from = f.fold_uint_expression(from); + let to = f.fold_uint_expression(to); + + match (from.into_inner(), to.into_inner()) { + (zir::UExpressionInner::Value(from), zir::UExpressionInner::Value(to)) => { + assert_eq!(size, to.saturating_sub(from) as usize); + + let element_size = ty.get_primitive_count(); + let start = from as usize * element_size; + let end = to as usize * element_size; + array[start..end].to_vec() + } + _ => unreachable!(), + } + } + typed_absy::ArrayExpressionInner::Repeat(box e, box count) => { + let e = f.fold_expression(e); + let count = f.fold_uint_expression(count); + + match count.into_inner() { + zir::UExpressionInner::Value(count) => { + vec![e; count as usize].into_iter().flatten().collect() + } + _ => unreachable!(), + } + } } } pub fn fold_struct_expression_inner<'ast, T: Field>( f: &mut Flattener, - t: &StructType, - e: typed_absy::StructExpressionInner<'ast, T>, + ty: &typed_absy::types::ConcreteStructType, + struc: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { - match e { + match struc { typed_absy::StructExpressionInner::Identifier(id) => { - let variables = - flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::struc(t.clone())); + let variables = flatten_identifier_rec( + f.fold_name(id), + &typed_absy::types::ConcreteType::struc(ty.clone()), + ); variables .into_iter() .map(|v| match v._type { @@ -457,13 +486,18 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty.clone()) + .unwrap() + .get_primitive_count() + }) .sum(); // we also need the size of this member - let size = t + let size = ty .iter() .find(|member| member.id == id) + .cloned() .unwrap() .ty .get_primitive_count(); @@ -472,15 +506,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( } typed_absy::StructExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - match index { - zir::FieldElementExpression::Number(i) => { - let size = t - .iter() - .map(|m| m.ty.get_primitive_count()) - .fold(0, |acc, current| acc + current); - let start = i.to_dec_string().parse::().unwrap() * size; + match index.into_inner() { + zir::UExpressionInner::Value(i) => { + let size: usize = ty.iter().map(|m| m.ty.get_primitive_count()).sum(); + let start = i as usize * size; let end = start + size; array[start..end].to_vec() } @@ -498,9 +529,12 @@ pub fn fold_field_expression<'ast, T: Field>( typed_absy::FieldElementExpression::Number(n) => zir::FieldElementExpression::Number(n), typed_absy::FieldElementExpression::Identifier(id) => { zir::FieldElementExpression::Identifier( - flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::FieldElement)[0] - .id - .clone(), + flatten_identifier_rec( + f.fold_name(id), + &typed_absy::types::ConcreteType::FieldElement, + )[0] + .id + .clone(), ) } typed_absy::FieldElementExpression::Add(box e1, box e2) => { @@ -525,7 +559,7 @@ pub fn fold_field_expression<'ast, T: Field>( } typed_absy::FieldElementExpression::Pow(box e1, box e2) => { let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); + let e2 = f.fold_uint_expression(e2); zir::FieldElementExpression::Pow(box e1, box e2) } typed_absy::FieldElementExpression::Neg(box e) => { @@ -556,26 +590,22 @@ pub fn fold_field_expression<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty.clone()) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - s[offset].clone().try_into().unwrap() } typed_absy::FieldElementExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => array - [i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(), + match index.into_inner() { + zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(), _ => unreachable!(""), } } @@ -589,7 +619,7 @@ pub fn fold_boolean_expression<'ast, T: Field>( match e { typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v), typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier( - flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Boolean)[0] + flatten_identifier_rec(f.fold_name(id), &typed_absy::types::ConcreteType::Boolean)[0] .id .clone(), ), @@ -665,25 +695,45 @@ pub fn fold_boolean_expression<'ast, T: Field>( zir::BooleanExpression::UintEq(box e1, box e2) } - typed_absy::BooleanExpression::Lt(box e1, box e2) => { + typed_absy::BooleanExpression::FieldLt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - zir::BooleanExpression::Lt(box e1, box e2) + zir::BooleanExpression::FieldLt(box e1, box e2) } - typed_absy::BooleanExpression::Le(box e1, box e2) => { + typed_absy::BooleanExpression::FieldLe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - zir::BooleanExpression::Le(box e1, box e2) + zir::BooleanExpression::FieldLe(box e1, box e2) } - typed_absy::BooleanExpression::Gt(box e1, box e2) => { + typed_absy::BooleanExpression::FieldGt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - zir::BooleanExpression::Gt(box e1, box e2) + zir::BooleanExpression::FieldGt(box e1, box e2) } - typed_absy::BooleanExpression::Ge(box e1, box e2) => { + typed_absy::BooleanExpression::FieldGe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - zir::BooleanExpression::Ge(box e1, box e2) + zir::BooleanExpression::FieldGe(box e1, box e2) + } + typed_absy::BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + zir::BooleanExpression::UintLt(box e1, box e2) + } + typed_absy::BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + zir::BooleanExpression::UintLe(box e1, box e2) + } + typed_absy::BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + zir::BooleanExpression::UintGt(box e1, box e2) + } + typed_absy::BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + zir::BooleanExpression::UintGe(box e1, box e2) } typed_absy::BooleanExpression::Or(box e1, box e2) => { let e1 = f.fold_boolean_expression(e1); @@ -714,25 +764,21 @@ pub fn fold_boolean_expression<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty.clone()) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - s[offset].clone().try_into().unwrap() } typed_absy::BooleanExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => array - [i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(), + match index.into_inner() { + zir::UExpressionInner::Value(i) => array[i as usize].clone().try_into().unwrap(), _ => unreachable!(), } } @@ -755,9 +801,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( match e { typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v), typed_absy::UExpressionInner::Identifier(id) => zir::UExpressionInner::Identifier( - flatten_identifier_rec(f.fold_name(id), &typed_absy::Type::Uint(bitwidth))[0] - .id - .clone(), + flatten_identifier_rec( + f.fold_name(id), + &typed_absy::types::ConcreteType::Uint(bitwidth), + )[0] + .id + .clone(), ), typed_absy::UExpressionInner::Add(box left, box right) => { let left = f.fold_uint_expression(left); @@ -771,6 +820,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( zir::UExpressionInner::Sub(box left, box right) } + typed_absy::UExpressionInner::FloorSub(..) => unreachable!(), typed_absy::UExpressionInner::Mult(box left, box right) => { let left = f.fold_uint_expression(left); let right = f.fold_uint_expression(right); @@ -827,11 +877,8 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( typed_absy::UExpressionInner::Neg(box e) => { let bitwidth = e.bitwidth(); - f.fold_uint_expression(typed_absy::UExpression::sub( - typed_absy::UExpressionInner::Value(0).annotate(bitwidth), - e, - )) - .into_inner() + f.fold_uint_expression(typed_absy::UExpressionInner::Value(0).annotate(bitwidth) - e) + .into_inner() } typed_absy::UExpressionInner::Pos(box e) => { @@ -844,16 +891,11 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( } typed_absy::UExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); - use std::convert::TryInto; - - match index { - zir::FieldElementExpression::Number(i) => { - let e: zir::UExpression<_> = array[i.to_dec_string().parse::().unwrap()] - .clone() - .try_into() - .unwrap(); + match index.into_inner() { + zir::UExpressionInner::Value(i) => { + let e: zir::UExpression<_> = array[i as usize].clone().try_into().unwrap(); e.into_inner() } _ => unreachable!(), @@ -867,11 +909,13 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( let offset: usize = members .iter() .take_while(|member| member.id != id) - .map(|member| member.ty.get_primitive_count()) + .map(|member| { + typed_absy::types::ConcreteType::try_from(*member.ty.clone()) + .unwrap() + .get_primitive_count() + }) .sum(); - use std::convert::TryInto; - let res: zir::UExpression<'ast, T> = s[offset].clone().try_into().unwrap(); res.into_inner() @@ -893,14 +937,18 @@ pub fn fold_function<'ast, T: Field>( arguments: fun .arguments .into_iter() - .flat_map(|a| f.fold_parameter(a)) + .flat_map(|a| f.fold_declaration_parameter(a)) .collect(), statements: fun .statements .into_iter() .flat_map(|s| f.fold_statement(s)) .collect(), - signature: fun.signature.into(), + signature: typed_absy::types::ConcreteSignature::try_from( + typed_absy::types::Signature::::try_from(fun.signature).unwrap(), + ) + .unwrap() + .into(), } } @@ -908,41 +956,46 @@ pub fn fold_array_expression<'ast, T: Field>( f: &mut Flattener, e: typed_absy::ArrayExpression<'ast, T>, ) -> Vec> { - f.fold_array_expression_inner(&e.inner_type().clone(), e.size(), e.into_inner()) + let size = match e.size().into_inner() { + typed_absy::UExpressionInner::Value(v) => v, + _ => unreachable!(), + } as usize; + f.fold_array_expression_inner( + &typed_absy::types::ConcreteType::try_from(e.inner_type().clone()).unwrap(), + size, + e.into_inner(), + ) } pub fn fold_struct_expression<'ast, T: Field>( f: &mut Flattener, e: typed_absy::StructExpression<'ast, T>, ) -> Vec> { - f.fold_struct_expression_inner(&e.ty().clone(), e.into_inner()) -} - -pub fn fold_function_symbol<'ast, T: Field>( - f: &mut Flattener, - s: typed_absy::TypedFunctionSymbol<'ast, T>, -) -> zir::ZirFunctionSymbol<'ast, T> { - match s { - typed_absy::TypedFunctionSymbol::Here(fun) => { - zir::ZirFunctionSymbol::Here(f.fold_function(fun)) - } - typed_absy::TypedFunctionSymbol::There(key, module) => { - zir::ZirFunctionSymbol::There(f.fold_function_key(key), module) - } // by default, do not fold modules recursively - typed_absy::TypedFunctionSymbol::Flat(flat) => zir::ZirFunctionSymbol::Flat(flat), - } + f.fold_struct_expression_inner( + &typed_absy::types::ConcreteStructType::try_from(e.ty().clone()).unwrap(), + e.into_inner(), + ) } pub fn fold_program<'ast, T: Field>( f: &mut Flattener, - p: typed_absy::TypedProgram<'ast, T>, + mut p: typed_absy::TypedProgram<'ast, T>, ) -> zir::ZirProgram<'ast, T> { + let main_module = p.modules.remove(&p.main).unwrap(); + + let main_function = main_module + .functions + .into_iter() + .find(|(key, _)| key.id == "main") + .unwrap() + .1; + + let main_function = match main_function { + typed_absy::TypedFunctionSymbol::Here(f) => f, + _ => unreachable!(), + }; + zir::ZirProgram { - modules: p - .modules - .into_iter() - .map(|(module_id, module)| (module_id, f.fold_module(module))) - .collect(), - main: p.main, + main: f.fold_function(main_function), } } diff --git a/zokrates_core/src/static_analysis/inline.rs b/zokrates_core/src/static_analysis/inline.rs deleted file mode 100644 index 4fad20a6..00000000 --- a/zokrates_core/src/static_analysis/inline.rs +++ /dev/null @@ -1,1387 +0,0 @@ -//! Module containing inlining for the typed AST -//! -//! @file inline.rs -//! @author Thibaut Schaeffer -//! @date 2019 - -//! Start from the `main` function in the `main` module and inline all calls except those to flat embeds -//! The resulting program has a single module, where we define a function for each flat embed and replace the function calls with the embeds found -//! during inlining by calls to these functions, to be resolved during flattening. - -//! The resulting program has a single module of the form - -//! def main(): -//! def _SHA_256_ROUND(): -//! def _UNPACK(): - -//! where any call in `main` must be to `_SHA_256_ROUND` or `_UNPACK` - -use crate::typed_absy::types::{FunctionKey, FunctionKeyHash, Type, UBitwidth}; -use crate::typed_absy::{folder::*, *}; -use std::collections::HashMap; -use zokrates_field::Field; - -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -struct Location<'ast> { - module: TypedModuleId, - key: FunctionKey<'ast>, -} - -impl<'ast> Location<'ast> { - fn module(&self) -> &TypedModuleId { - &self.module - } -} - -/// An inliner -#[derive(Debug)] -pub struct Inliner<'ast, T: Field> { - /// the modules in which to look for functions when inlining - modules: TypedModules<'ast, T>, - /// the current module we're visiting - location: Location<'ast>, - /// a buffer of statements to be added to the inlined statements - statement_buffer: Vec>, - /// the current call stack - stack: Vec<(TypedModuleId, FunctionKeyHash, usize)>, - /// the call count for each function - call_count: HashMap<(TypedModuleId, FunctionKey<'ast>), usize>, -} - -impl<'ast, T: Field> Inliner<'ast, T> { - fn with_modules_and_module_id_and_key>( - modules: TypedModules<'ast, T>, - module_id: S, - key: FunctionKey<'ast>, - ) -> Self { - Inliner { - modules, - location: Location { - module: module_id.into(), - key, - }, - statement_buffer: vec![], - stack: vec![], - call_count: HashMap::new(), - } - } - - pub fn inline(p: TypedProgram) -> TypedProgram { - let main_module_id = p.main; - - // get the main module - let main_module = p.modules.get(&main_module_id).unwrap().clone(); - - // get the main function in the main module - let (main_key, main) = main_module - .functions - .into_iter() - .find(|(k, _)| k.id == "main") - .unwrap(); - - // initialize an inliner over all modules, starting from the main module - let mut inliner = Inliner::with_modules_and_module_id_and_key( - p.modules, - main_module_id, - main_key.clone(), - ); - - // inline all calls in the main function, recursively - let main = inliner.fold_function_symbol(main); - - cfg_if::cfg_if! { - if #[cfg(feature = "bellman")] { - // define a function in the main module for the `sha256` embed - let sha256_round = crate::embed::FlatEmbed::Sha256Round; - let sha256_round_key = sha256_round.key::(); - } - } - - // define a function in the main module for the `unpack` embed - let unpack = crate::embed::FlatEmbed::Unpack(T::get_required_bits()); - let unpack_key = unpack.key::(); - - // define a function in the main module for the `u32_to_bits` embed - let u32_to_bits = crate::embed::FlatEmbed::U32ToBits; - let u32_to_bits_key = u32_to_bits.key::(); - - // define a function in the main module for the `u16_to_bits` embed - let u16_to_bits = crate::embed::FlatEmbed::U16ToBits; - let u16_to_bits_key = u16_to_bits.key::(); - - // define a function in the main module for the `u8_to_bits` embed - let u8_to_bits = crate::embed::FlatEmbed::U8ToBits; - let u8_to_bits_key = u8_to_bits.key::(); - - // define a function in the main module for the `u32_from_bits` embed - let u32_from_bits = crate::embed::FlatEmbed::U32FromBits; - let u32_from_bits_key = u32_from_bits.key::(); - - // define a function in the main module for the `u16_from_bits` embed - let u16_from_bits = crate::embed::FlatEmbed::U16FromBits; - let u16_from_bits_key = u16_from_bits.key::(); - - // define a function in the main module for the `u8_from_bits` embed - let u8_from_bits = crate::embed::FlatEmbed::U8FromBits; - let u8_from_bits_key = u8_from_bits.key::(); - - // return a program with a single module containing `main`, `_UNPACK`, and `_SHA256_ROUND - TypedProgram { - main: "main".into(), - modules: vec![( - "main".into(), - TypedModule { - functions: vec![ - #[cfg(feature = "bellman")] - (sha256_round_key, TypedFunctionSymbol::Flat(sha256_round)), - (unpack_key, TypedFunctionSymbol::Flat(unpack)), - (u32_from_bits_key, TypedFunctionSymbol::Flat(u32_from_bits)), - (u16_from_bits_key, TypedFunctionSymbol::Flat(u16_from_bits)), - (u8_from_bits_key, TypedFunctionSymbol::Flat(u8_from_bits)), - (u32_to_bits_key, TypedFunctionSymbol::Flat(u32_to_bits)), - (u16_to_bits_key, TypedFunctionSymbol::Flat(u16_to_bits)), - (u8_to_bits_key, TypedFunctionSymbol::Flat(u8_to_bits)), - (main_key, main), - ] - .into_iter() - .collect(), - }, - )] - .into_iter() - .collect(), - } - } - - /// try to inline a call to function with key `key` in the stack of `self` - /// if inlining succeeds, return the expressions returned by the function call - /// if inlining fails (as in the case of flat function symbols), return the arguments to the function call for further processing - fn try_inline_call( - &mut self, - key: &FunctionKey<'ast>, - expressions: Vec>, - ) -> Result>, (FunctionKey<'ast>, Vec>)> - { - // here we clone a function symbol, which is cheap except when it contains the function body, in which case we'd clone anyways - let res = match self.module().functions.get(&key).unwrap().clone() { - // if the function called is in the same module, we can go ahead and inline in this module - TypedFunctionSymbol::Here(function) => { - let (current_module, current_key) = - self.change_context(self.module_id().clone(), key.clone()); - - let module_id = self.module_id().clone(); - - // increase the number of calls for this function by one - let count = self - .call_count - .entry((self.module_id().clone(), key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1); - // push this call to the stack - self.stack.push((module_id, key.hash(), *count)); - // add definitions for the inputs - let inputs_bindings: Vec<_> = function - .arguments - .iter() - .zip(expressions.clone()) - .map(|(a, e)| { - TypedStatement::Definition( - self.fold_assignee(TypedAssignee::Identifier(a.id.clone())), - e, - ) - }) - .collect(); - - self.statement_buffer.extend(inputs_bindings); - - // filter out the return statement and keep it aside - let (statements, mut ret): (Vec<_>, Vec<_>) = function - .statements - .into_iter() - .flat_map(|s| self.fold_statement(s)) - .partition(|s| match s { - TypedStatement::Return(..) => false, - _ => true, - }); - - // add all statements to the buffer - self.statement_buffer.extend(statements); - - // pop this call from the stack - self.stack.pop(); - - self.change_context(current_module, current_key); - - match ret.pop().unwrap() { - TypedStatement::Return(exprs) => Ok(exprs), - _ => unreachable!(""), - } - } - // if the function called is in some other module, we switch focus to that module and call the function locally there - TypedFunctionSymbol::There(function_key, module_id) => { - // switch focus to `module_id` - let (current_module, current_key) = - self.change_context(module_id, function_key.clone()); - // inline the call there - let res = self.try_inline_call(&function_key, expressions.clone())?; - // switch back focus - self.change_context(current_module, current_key); - Ok(res) - } - // if the function is a flat symbol, replace the call with a call to the local function we provide so it can be inlined in flattening - TypedFunctionSymbol::Flat(embed) => Err((embed.key::(), expressions.clone())), - }; - - res - } - - // Focus the inliner on another module with id `module_id` and return the current `module_id` - fn change_context( - &mut self, - module_id: TypedModuleId, - function_key: FunctionKey<'ast>, - ) -> (TypedModuleId, FunctionKey<'ast>) { - let current_module = std::mem::replace(&mut self.location.module, module_id); - let current_key = std::mem::replace(&mut self.location.key, function_key); - (current_module, current_key) - } - - fn module(&self) -> &TypedModule<'ast, T> { - self.modules.get(self.module_id()).unwrap() - } - - fn module_id(&self) -> &TypedModuleId { - self.location.module() - } -} - -impl<'ast, T: Field> Folder<'ast, T> for Inliner<'ast, T> { - // add extra statements before the modified statement - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - let folded = match s { - TypedStatement::MultipleDefinition(assignees, elist) => match elist { - TypedExpressionList::FunctionCall(key, exps, types) => { - let assignees: Vec<_> = assignees - .into_iter() - .map(|a| self.fold_assignee(a)) - .collect(); - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(ret) => assignees - .into_iter() - .zip(ret.into_iter()) - .map(|(a, e)| TypedStatement::Definition(a, e)) - .collect(), - Err((key, expressions)) => vec![TypedStatement::MultipleDefinition( - assignees, - TypedExpressionList::FunctionCall(key, expressions, types), - )], - } - } - }, - s => fold_statement(self, s), - }; - self.statement_buffer.drain(..).chain(folded).collect() - } - - // prefix all names with the stack - fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - Identifier { - stack: self.stack.clone(), - ..n - } - } - - // inline calls which return a field element - fn fold_field_expression( - &mut self, - e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { - match e { - FieldElementExpression::FunctionCall(key, exps) => { - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(mut ret) => match ret.pop().unwrap() { - TypedExpression::FieldElement(e) => e, - _ => unreachable!(), - }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call( - key.hash(), - *self - .call_count - .entry((self.module_id().clone(), embed_key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1), - ), - version: 0, - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone()).into()], - TypedExpressionList::FunctionCall( - key.clone(), - expressions.clone(), - tys, - ), - )); - - FieldElementExpression::Identifier(id) - } - } - } - e => fold_field_expression(self, e), - } - } - - // inline calls which return a boolean element - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { - match e { - BooleanExpression::FunctionCall(key, exps) => { - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(mut ret) => match ret.pop().unwrap() { - TypedExpression::Boolean(e) => e, - _ => unreachable!(), - }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call( - key.hash(), - *self - .call_count - .entry((self.module_id().clone(), embed_key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1), - ), - version: 0, - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone()).into()], - TypedExpressionList::FunctionCall( - key.clone(), - expressions.clone(), - tys, - ), - )); - - BooleanExpression::Identifier(id) - } - } - } - e => fold_boolean_expression(self, e), - } - } - - // inline calls which return an array - fn fold_array_expression_inner( - &mut self, - ty: &Type, - size: usize, - e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { - match e { - ArrayExpressionInner::FunctionCall(key, exps) => { - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(mut ret) => match ret.pop().unwrap() { - TypedExpression::Array(e) => e.into_inner(), - _ => unreachable!(), - }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call( - key.hash(), - *self - .call_count - .entry((self.module_id().clone(), embed_key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1), - ), - version: 0, - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone()).into()], - TypedExpressionList::FunctionCall( - embed_key.clone(), - expressions.clone(), - tys, - ), - )); - - ArrayExpressionInner::Identifier(id) - } - } - } - // default - e => fold_array_expression_inner(self, ty, size, e), - } - } - - fn fold_struct_expression_inner( - &mut self, - ty: &StructType, - e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { - match e { - StructExpressionInner::FunctionCall(key, exps) => { - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(mut ret) => match ret.pop().unwrap() { - TypedExpression::Struct(e) => e.into_inner(), - _ => unreachable!(), - }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call( - key.hash(), - *self - .call_count - .entry((self.module_id().clone(), embed_key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1), - ), - version: 0, - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone()).into()], - TypedExpressionList::FunctionCall( - key.clone(), - expressions.clone(), - tys, - ), - )); - - StructExpressionInner::Identifier(id) - } - } - } - // default - e => fold_struct_expression_inner(self, ty, e), - } - } - - fn fold_uint_expression_inner( - &mut self, - size: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { - match e { - UExpressionInner::FunctionCall(key, exps) => { - let exps: Vec<_> = exps.into_iter().map(|e| self.fold_expression(e)).collect(); - - match self.try_inline_call(&key, exps) { - Ok(mut ret) => match ret.pop().unwrap() { - TypedExpression::Uint(e) => e.into_inner(), - _ => unreachable!(), - }, - Err((embed_key, expressions)) => { - let tys = key.signature.outputs.clone(); - let id = Identifier { - id: CoreIdentifier::Call( - key.hash(), - *self - .call_count - .entry((self.module_id().clone(), embed_key.clone())) - .and_modify(|i| *i += 1) - .or_insert(1), - ), - version: 0, - stack: self.stack.clone(), - }; - self.statement_buffer - .push(TypedStatement::MultipleDefinition( - vec![Variable::with_id_and_type(id.clone(), tys[0].clone()).into()], - TypedExpressionList::FunctionCall( - embed_key.clone(), - expressions.clone(), - tys, - ), - )); - - UExpressionInner::Identifier(id) - } - } - } - // default - e => fold_uint_expression_inner(self, size, e), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::typed_absy::types::{FunctionKey, Signature, Type}; - use std::path::PathBuf; - use zokrates_field::Bn128Field; - - #[test] - fn call_other_module_without_variables() { - // // main - // from "foo" import foo - // def main() -> field: - // return foo() - // - // // foo - // def foo() -> field: - // return 42 - // - // - // // inlined - // def main() -> field: - // return 42 - - let main = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - vec![], - ) - .into(), - ])], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - ), - ( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::There( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - "foo".into(), - ), - ), - ] - .into_iter() - .collect(), - }; - - let foo = TypedModule { - functions: vec![( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(42)).into(), - ])], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - )] - .into_iter() - .collect(), - }; - - let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - assert_eq!(program.modules.len(), 1); - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get( - &FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])) - ) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(42)).into(), - ])], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }) - ); - } - - #[test] - fn call_other_module_with_variables() { - // // main - // from "foo" import foo - // def main(field a) -> field: - // return a * foo(a) - // - // // foo - // def foo(field a) -> field: - // return a * a - // - // - // // inlined - // def main(a) -> field: - // field a_0 = a - // return a * a_0 * a_0 - - let main = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Mult( - box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - vec![FieldElementExpression::Identifier("a".into()).into()], - ), - ) - .into(), - ])], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }), - ), - ( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::There( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - "foo".into(), - ), - ), - ] - .into_iter() - .collect(), - }; - - let foo = TypedModule { - functions: vec![( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![TypedStatement::Return(vec![FieldElementExpression::Mult( - box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Identifier("a".into()), - ) - .into()])], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }), - )] - .into_iter() - .collect(), - }; - - let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] - .into_iter() - .collect(); - - let program: TypedProgram = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - assert_eq!(program.modules.len(), 1); - - let stack = vec![( - "foo".into(), - FunctionKey::with_id("foo") - .signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ) - .hash(), - 1, - )]; - - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get( - &FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]) - ) - ) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(stack.clone()) - )), - FieldElementExpression::Identifier("a".into()).into() - ), - TypedStatement::Return(vec![FieldElementExpression::Mult( - box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Mult( - box FieldElementExpression::Identifier( - Identifier::from("a").stack(stack.clone()) - ), - box FieldElementExpression::Identifier( - Identifier::from("a").stack(stack.clone()) - ) - ) - ) - .into(),]) - ], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }) - ); - } - - #[test] - fn only_memoize_in_same_function() { - // // foo - // def foo(field a) -> field: - // return a - - // // main - // def main(field a) -> field: - // field b = foo(a) + bar(a) - // return b - // - // def bar(field a) -> field: - // return foo(a) - - // inlined - // def main(field a) -> field - // field _0 = a + a - // return _0 - - let signature = Signature::new() - .outputs(vec![Type::FieldElement]) - .inputs(vec![Type::FieldElement]); - - let main: TypedModule = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main").signature( - Signature::new() - .outputs(vec![Type::FieldElement]) - .inputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter { - id: Variable::field_element("a"), - private: true, - }], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("b")), - FieldElementExpression::Add( - box FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo").signature(signature.clone()), - vec![FieldElementExpression::Identifier("a".into()).into()], - ), - box FieldElementExpression::FunctionCall( - FunctionKey::with_id("bar").signature(signature.clone()), - vec![FieldElementExpression::Identifier("a".into()).into()], - ), - ) - .into(), - ), - TypedStatement::Return(vec![FieldElementExpression::Identifier( - "b".into(), - ) - .into()]), - ], - signature: signature.clone(), - }), - ), - ( - FunctionKey::with_id("bar").signature(signature.clone()), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter { - id: Variable::field_element("a"), - private: true, - }], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::FunctionCall( - FunctionKey::with_id("foo").signature(signature.clone()), - vec![FieldElementExpression::Identifier("a".into()).into()], - ) - .into(), - ])], - signature: signature.clone(), - }), - ), - ( - FunctionKey::with_id("foo").signature(signature.clone()), - TypedFunctionSymbol::There( - FunctionKey::with_id("foo").signature(signature.clone()), - "foo".into(), - ), - ), - ] - .into_iter() - .collect(), - }; - - let foo: TypedModule = TypedModule { - functions: vec![( - FunctionKey::with_id("foo").signature(signature.clone()), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter { - id: Variable::field_element("a"), - private: true, - }], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier("a".into()).into(), - ])], - signature: signature.clone(), - }), - )] - .into_iter() - .collect(), - }; - - let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - assert_eq!(program.modules.len(), 1); - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get(&FunctionKey::with_id("main").signature(signature.clone())) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter { - id: Variable::field_element("a"), - private: true, - }], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(vec![( - "foo".into(), - FunctionKey::with_id("foo") - .signature(signature.clone()) - .hash(), - 1 - )]) - )), - FieldElementExpression::Identifier("a".into()).into() - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(vec![( - "main".into(), - FunctionKey::with_id("bar") - .signature(signature.clone()) - .hash(), - 1 - )]) - )), - FieldElementExpression::Identifier("a".into()).into() - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(vec![ - ( - "main".into(), - FunctionKey::with_id("bar") - .signature(signature.clone()) - .hash(), - 1 - ), - ( - "foo".into(), - FunctionKey::with_id("foo") - .signature(signature.clone()) - .hash(), - 2 - ) - ]) - )), - FieldElementExpression::Identifier(Identifier::from("a").stack(vec![( - "main".into(), - FunctionKey::with_id("bar").signature(signature.clone()).hash(), - 1 - )])) - .into() - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("b")), - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from("a").stack( - vec![( - "foo".into(), - FunctionKey::with_id("foo").signature(signature.clone()).hash(), - 1 - )] - )), - box FieldElementExpression::Identifier(Identifier::from("a").stack( - vec![ - ( - "main".into(), - FunctionKey::with_id("bar") - .signature(signature.clone()) - .hash(), - 1 - ), - ( - "foo".into(), - FunctionKey::with_id("foo") - .signature(signature.clone()) - .hash(), - 2 - ) - ] - )) - ) - .into() - ), - TypedStatement::Return(vec![ - FieldElementExpression::Identifier("b".into()).into(), - ]) - ], - signature: signature.clone(), - }) - ); - } - - #[test] - fn multi_def_from_other_module() { - // // foo - // def foo() -> field: - // return 42 - - // // main - // def main() -> field: - // field b = foo() - // return b - - // inlined - // def main() -> field - // field _0 = 42 - // return _0 - - let main = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![ - TypedStatement::MultipleDefinition( - vec![Variable::field_element("a").into()], - TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new().outputs(vec![Type::FieldElement]), - ), - vec![], - vec![Type::FieldElement], - ), - ), - TypedStatement::Return(vec![FieldElementExpression::Identifier( - "a".into(), - ) - .into()]), - ], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - ), - ( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::There( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - "foo".into(), - ), - ), - ] - .into_iter() - .collect(), - }; - - let foo = TypedModule { - functions: vec![( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(42)).into(), - ])], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - )] - .into_iter() - .collect(), - }; - - let modules: HashMap<_, _> = vec![("main".into(), main), ("foo".into(), foo)] - .into_iter() - .collect(); - - let program = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - assert_eq!(program.modules.len(), 1); - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get( - &FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])) - ) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(42)).into() - ), - TypedStatement::Return(vec![ - FieldElementExpression::Identifier("a".into()).into(), - ]) - ], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }) - ); - } - - #[test] - fn multi_def_from_same_module() { - // // main - // def foo() -> field: - // return 42 - // def main() -> field: - // field a = foo() - // return a - - // inlined - // def main() -> field - // field _0 = 42 - // return _0 - - let main = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![ - TypedStatement::MultipleDefinition( - vec![Variable::field_element("a").into()], - TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new().outputs(vec![Type::FieldElement]), - ), - vec![], - vec![Type::FieldElement], - ), - ), - TypedStatement::Return(vec![FieldElementExpression::Identifier( - "a".into(), - ) - .into()]), - ], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - ), - ( - FunctionKey::with_id("foo") - .signature(Signature::new().outputs(vec![Type::FieldElement])), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Number(Bn128Field::from(42)).into(), - ])], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }), - ), - ] - .into_iter() - .collect(), - }; - - let modules: HashMap<_, _> = vec![("main".into(), main)].into_iter().collect(); - - let program = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - assert_eq!(program.modules.len(), 1); - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get( - &FunctionKey::with_id("main") - .signature(Signature::new().outputs(vec![Type::FieldElement])) - ) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(42)).into() - ), - TypedStatement::Return(vec![ - FieldElementExpression::Identifier("a".into()).into(), - ]) - ], - signature: Signature::new().outputs(vec![Type::FieldElement]), - }) - ); - } - - #[test] - fn recursive_call_in_other_module() { - // // main - // def main(field a) -> field: - // return id(id(a)) - - // // id - // def main(field a) -> field - // return a - - // inlined - // def main(field a) -> field - // id_main_0_a = a - // id_main_1_a = id_main_0_a - // return id_main_1_a - - let main = TypedModule { - functions: vec![ - ( - FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::FunctionCall( - FunctionKey::with_id("id").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - vec![FieldElementExpression::FunctionCall( - FunctionKey::with_id("id").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - vec![FieldElementExpression::Identifier("a".into()).into()], - ) - .into()], - ) - .into(), - ])], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }), - ), - ( - FunctionKey::with_id("id").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::There( - FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - "id".into(), - ), - ), - ] - .into_iter() - .collect(), - }; - - let id = TypedModule { - functions: vec![( - FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![TypedStatement::Return(vec![ - FieldElementExpression::Identifier("a".into()).into(), - ])], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }), - )] - .into_iter() - .collect(), - }; - - let modules = vec![("main".into(), main), ("id".into(), id)] - .into_iter() - .collect(); - - let program: TypedProgram = TypedProgram { - main: "main".into(), - modules, - }; - - let program = Inliner::inline(program); - - let stack0 = vec![( - "id".into(), - FunctionKey::with_id("main") - .signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ) - .hash(), - 1, - )]; - let stack1 = vec![( - "id".into(), - FunctionKey::with_id("main") - .signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ) - .hash(), - 2, - )]; - - assert_eq!(program.modules.len(), 1); - assert_eq!( - program - .modules - .get(&PathBuf::from("main")) - .unwrap() - .functions - .get( - &FunctionKey::with_id("main").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]) - ) - ) - .unwrap(), - &TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![Parameter::private(Variable::field_element("a"))], - statements: vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(stack0.clone()) - )), - FieldElementExpression::Identifier("a".into()).into() - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").stack(stack1.clone()) - )), - FieldElementExpression::Identifier( - Identifier::from("a").stack(stack0.clone()) - ) - .into() - ), - TypedStatement::Return(vec![FieldElementExpression::Identifier( - Identifier::from("a").stack(stack1.clone()) - ) - .into(),]) - ], - signature: Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - }) - ); - } -} diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 83fb69ca..7768ba56 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -4,69 +4,93 @@ //! @author Thibaut Schaeffer //! @date 2018 +mod bounds_checker; mod flat_propagation; mod flatten_complex_types; -mod inline; -mod propagate_unroll; mod propagation; mod redefinition; -mod return_binder; +mod reducer; mod uint_optimizer; mod unconstrained_vars; -mod unroll; mod variable_read_remover; mod variable_write_remover; +use self::bounds_checker::BoundsChecker; use self::flatten_complex_types::Flattener; -use self::inline::Inliner; -use self::propagate_unroll::PropagatedUnroller; use self::propagation::Propagator; use self::redefinition::RedefinitionOptimizer; -use self::return_binder::ReturnBinder; +use self::reducer::reduce_program; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_read_remover::VariableReadRemover; use self::variable_write_remover::VariableWriteRemover; use crate::flat_absy::FlatProg; use crate::ir::Prog; -use crate::typed_absy::TypedProgram; +use crate::typed_absy::{abi::Abi, TypedProgram}; use crate::zir::ZirProgram; +use std::fmt; use zokrates_field::Field; pub trait Analyse { fn analyse(self) -> Self; } +#[derive(Debug)] +pub enum Error { + Reducer(self::reducer::Error), + OutOfBounds(self::bounds_checker::Error), + Propagation(self::propagation::Error), +} + +impl From for Error { + fn from(e: self::reducer::Error) -> Self { + Error::Reducer(e) + } +} + +impl From for Error { + fn from(e: bounds_checker::Error) -> Self { + Error::OutOfBounds(e) + } +} + +impl From for Error { + fn from(e: propagation::Error) -> Self { + Error::Propagation(e) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Reducer(e) => write!(f, "{}", e), + Error::OutOfBounds(e) => write!(f, "{}", e), + Error::Propagation(e) => write!(f, "{}", e), + } + } +} impl<'ast, T: Field> TypedProgram<'ast, T> { - pub fn analyse(self) -> ZirProgram<'ast, T> { - // propagated unrolling - let r = PropagatedUnroller::unroll(self).unwrap_or_else(|e| panic!("{}", e)); + pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> { + let r = reduce_program(self).map_err(Error::from)?; - // return binding - let r = ReturnBinder::bind(r); - - // inline - let r = Inliner::inline(r); + let abi = r.abi(); // propagate - let r = Propagator::propagate(r); - + let r = Propagator::propagate(r).map_err(Error::from)?; // optimize redefinitions let r = RedefinitionOptimizer::optimize(r); - // remove assignment to variable index let r = VariableWriteRemover::apply(r); - // remove variable access to complex types let r = VariableReadRemover::apply(r); - + // check array accesses are in bounds + let r = BoundsChecker::check(r).map_err(Error::from)?; // convert to zir, removing complex types let zir = Flattener::flatten(r); - // optimize uint expressions let zir = UintOptimizer::optimize(zir); - zir + Ok((zir, abi)) } } @@ -78,7 +102,6 @@ impl Analyse for FlatProg { impl Analyse for Prog { fn analyse(self) -> Self { - let r = UnconstrainedVariableDetector::detect(self); - r + UnconstrainedVariableDetector::detect(self) } } diff --git a/zokrates_core/src/static_analysis/propagate_unroll.rs b/zokrates_core/src/static_analysis/propagate_unroll.rs deleted file mode 100644 index 67354b53..00000000 --- a/zokrates_core/src/static_analysis/propagate_unroll.rs +++ /dev/null @@ -1,236 +0,0 @@ -//! Module containing iterative unrolling in order to unroll nested loops with variable bounds -//! -//! For example: -//! ```zokrates -//! for field i in 0..5 do -//! for field j in i..5 do -//! // -//! endfor -//! endfor -//! ``` -//! -//! We can unroll the outer loop, but to unroll the inner one we need to propagate the value of `i` to the lower bound of the loop -//! -//! This module does exactly that: -//! - unroll the outter loop, detecting that it cannot unroll the inner one as the lower `i` bound isn't constant -//! - apply constant propagation to the program, *not visiting statements of loops whose bounds are not constant yet* -//! - unroll again, this time the 5 inner loops all have constant bounds -//! -//! In the case that a loop bound cannot be reduced to a constant, we detect it by noticing that the unroll does -//! not make progress anymore. - -use crate::static_analysis::propagation::Propagator; -use crate::static_analysis::unroll::{Output, Unroller}; -use crate::typed_absy::TypedProgram; -use zokrates_field::Field; - -pub struct PropagatedUnroller; - -impl PropagatedUnroller { - pub fn unroll<'ast, T: Field>( - p: TypedProgram<'ast, T>, - ) -> Result, &'static str> { - let mut blocked_at = None; - - // unroll a first time, retrieving whether the unroll is complete - let mut unrolled = Unroller::unroll(p); - - loop { - // conditions to exit the loop - unrolled = match unrolled { - Output::Complete(p) => return Ok(p), - Output::Incomplete(next, index) => { - if Some(index) == blocked_at { - return Err("Loop unrolling failed. This happened because a loop bound is not constant"); - } else { - // update the index where we blocked - blocked_at = Some(index); - - // propagate - let propagated = Propagator::propagate_verbose(next); - - // unroll - Unroller::unroll(propagated) - } - } - }; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::typed_absy::types::{FunctionKey, Signature}; - use crate::typed_absy::*; - use zokrates_field::Bn128Field; - - #[test] - fn detect_non_constant_bound() { - let loops = vec![TypedStatement::For( - Variable::field_element("i"), - FieldElementExpression::Identifier("i".into()), - FieldElementExpression::Number(Bn128Field::from(2)), - vec![], - )]; - - let statements = loops; - - let p = TypedProgram { - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - FunctionKey::with_id("main"), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - signature: Signature::new(), - statements, - }), - )] - .into_iter() - .collect(), - }, - )] - .into_iter() - .collect(), - main: "main".into(), - }; - - assert!(PropagatedUnroller::unroll(p).is_err()); - } - - #[test] - fn for_loop() { - // for field i in 0..2 - // for field j in i..2 - // field foo = i + j - - // should be unrolled to - // i_0 = 0 - // j_0 = 0 - // foo_0 = i_0 + j_0 - // j_1 = 1 - // foo_1 = i_0 + j_1 - // i_1 = 1 - // j_2 = 1 - // foo_2 = i_1 + j_1 - - let s = TypedStatement::For( - Variable::field_element("i"), - FieldElementExpression::Number(Bn128Field::from(0)), - FieldElementExpression::Number(Bn128Field::from(2)), - vec![TypedStatement::For( - Variable::field_element("j"), - FieldElementExpression::Identifier("i".into()), - FieldElementExpression::Number(Bn128Field::from(2)), - vec![ - TypedStatement::Declaration(Variable::field_element("foo")), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("foo")), - FieldElementExpression::Add( - box FieldElementExpression::Identifier("i".into()), - box FieldElementExpression::Identifier("j".into()), - ) - .into(), - ), - ], - )], - ); - - let expected_statements = vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("i").version(0), - )), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("j").version(0), - )), - FieldElementExpression::Number(Bn128Field::from(0)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(0), - )), - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from("i").version(0)), - box FieldElementExpression::Identifier(Identifier::from("j").version(0)), - ) - .into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("j").version(1), - )), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(1), - )), - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from("i").version(0)), - box FieldElementExpression::Identifier(Identifier::from("j").version(1)), - ) - .into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("i").version(1), - )), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("j").version(2), - )), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(2), - )), - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from("i").version(1)), - box FieldElementExpression::Identifier(Identifier::from("j").version(2)), - ) - .into(), - ), - ]; - - let p = TypedProgram { - modules: vec![( - "main".into(), - TypedModule { - functions: vec![( - FunctionKey::with_id("main"), - TypedFunctionSymbol::Here(TypedFunction { - arguments: vec![], - signature: Signature::new(), - statements: vec![s], - }), - )] - .into_iter() - .collect(), - }, - )] - .into_iter() - .collect(), - main: "main".into(), - }; - - let statements = match PropagatedUnroller::unroll(p).unwrap().modules - [std::path::Path::new("main")] - .functions[&FunctionKey::with_id("main")] - .clone() - { - TypedFunctionSymbol::Here(f) => f.statements, - _ => unreachable!(), - }; - - assert_eq!(statements, expected_statements); - } -} diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 6626cd22..92970cff 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -1,60 +1,62 @@ //! Module containing constant propagation for the typed AST //! -//! On top of the usual behavior of removing statements which assign a constant to a variable (as the variable can simply be -//! substituted for the constant whenever used), we provide a `verbose` mode which does not remove such statements. This is done -//! as for partial passes which do not visit the whole program, the variables being defined may be be used in parts of the program -//! that are not visited. Keeping the statements is semantically equivalent and enables rebuilding the set of constants at the -//! next pass. +//! Constant propagation on the SSA program. The constants map can be passed by the caller to allow for many passes to use +//! the same constants. //! //! @file propagation.rs //! @author Thibaut Schaeffer //! @date 2018 -use crate::typed_absy::folder::*; +use crate::embed::FlatEmbed; +use crate::typed_absy::result_folder::*; use crate::typed_absy::types::Type; use crate::typed_absy::*; use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; +use std::fmt; use zokrates_field::Field; -pub struct Propagator<'ast, T: Field> { - // constants keeps track of constant expressions - // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is - constants: HashMap, TypedExpression<'ast, T>>, - // the verbose mode doesn't remove statements which assign constants to variables - // it's required when using propagation in combination with unrolling - verbose: bool, +type Constants<'ast, T> = HashMap, TypedExpression<'ast, T>>; + +#[derive(Debug, PartialEq)] +pub enum Error { + Type(String), } -impl<'ast, T: Field> Propagator<'ast, T> { - fn verbose() -> Self { - Propagator { - constants: HashMap::new(), - verbose: true, +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Type(s) => write!(f, "{}", s), } } +} + +pub struct Propagator<'ast, 'a, T: Field> { + // constants keeps track of constant expressions + // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is + constants: &'a mut Constants<'ast, T>, +} + +impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { + pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self { + Propagator { constants } + } + + pub fn propagate(p: TypedProgram<'ast, T>) -> Result, Error> { + let mut constants = Constants::new(); - fn new() -> Self { Propagator { - constants: HashMap::new(), - verbose: false, + constants: &mut constants, } - } - - pub fn propagate(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - Propagator::new().fold_program(p) - } - - pub fn propagate_verbose(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - Propagator::verbose().fold_program(p) + .fold_program(p) } // get a mutable reference to the constant corresponding to a given assignee if any, otherwise // return the identifier at the root of this assignee - fn try_get_constant_mut<'a>( + fn try_get_constant_mut<'b>( &mut self, - assignee: &'a TypedAssignee<'ast, T>, - ) -> Result<(&'a Variable<'ast>, &mut TypedExpression<'ast, T>), &'a Variable<'ast>> { + assignee: &'b TypedAssignee<'ast, T>, + ) -> Result<(&'b Variable<'ast, T>, &mut TypedExpression<'ast, T>), &'b Variable<'ast, T>> { match assignee { TypedAssignee::Identifier(var) => self .constants @@ -63,19 +65,20 @@ impl<'ast, T: Field> Propagator<'ast, T> { .unwrap_or(Err(var)), TypedAssignee::Select(box assignee, box index) => { match self.try_get_constant_mut(&assignee) { - Ok((v, c)) => match index { - FieldElementExpression::Number(n) => { - let n = n.to_dec_string().parse::().unwrap(); - - match c { - TypedExpression::Array(a) => match a.as_inner_mut() { - ArrayExpressionInner::Value(value) => Ok((v, &mut value[n])), + Ok((variable, constant)) => match index.as_inner() { + UExpressionInner::Value(n) => match constant { + TypedExpression::Array(a) => match a.as_inner_mut() { + ArrayExpressionInner::Value(value) => match value.0[*n as usize] { + TypedExpressionOrSpread::Expression(ref mut e) => { + Ok((variable, e)) + } _ => unreachable!(), }, _ => unreachable!(), - } - } - _ => Err(v), + }, + _ => unreachable!(), + }, + _ => Err(variable), }, e => e, } @@ -107,86 +110,200 @@ impl<'ast, T: Field> Propagator<'ast, T> { } } -fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool { +fn is_constant(e: &TypedExpression) -> bool { match e { TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true, TypedExpression::Boolean(BooleanExpression::Value(..)) => true, TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), + ArrayExpressionInner::Value(v) => v.0.iter().all(|e| match e { + TypedExpressionOrSpread::Expression(e) => is_constant(e), + _ => false, + }), + ArrayExpressionInner::Slice(box a, box from, box to) => { + is_constant(&from.clone().into()) + && is_constant(&to.clone().into()) + && is_constant(&a.clone().into()) + } + ArrayExpressionInner::Repeat(box e, box count) => { + is_constant(&count.clone().into()) && is_constant(&e) + } _ => false, }, TypedExpression::Struct(a) => match a.as_inner() { StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), _ => false, }, - TypedExpression::Uint(a) => match a.as_inner() { - UExpressionInner::Value(..) => true, - _ => false, - }, + TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)), _ => false, } } -impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { - fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { - self.constants = HashMap::new(); +fn remove_spreads(e: TypedExpression) -> TypedExpression { + fn remove_spreads_aux(e: TypedExpressionOrSpread) -> Vec> { + match e { + TypedExpressionOrSpread::Expression(e) => vec![e], + TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { + ArrayExpressionInner::Value(v) => { + v.into_iter().flat_map(remove_spreads_aux).collect() + } + _ => unimplemented!(), + }, + } + } + + match e { + TypedExpression::Array(a) => { + let array_ty = a.get_array_type(); + + match a.into_inner() { + ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( + v.into_iter() + .flat_map(remove_spreads_aux) + .map(|e| e.into()) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into(), + ArrayExpressionInner::Slice(box a, box from, box to) => { + let from = match from.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!(), + }; + + let to = match to.into_inner() { + UExpressionInner::Value(to) => to as usize, + _ => unreachable!(), + }; + + let v = match a.into_inner() { + ArrayExpressionInner::Value(v) => v, + _ => unreachable!(), + }; + + ArrayExpressionInner::Value( + v.into_iter() + .flat_map(remove_spreads_aux) + .map(|e| e.into()) + .enumerate() + .filter(|(index, _)| index >= &from && index < &to) + .map(|(_, e)| e) + .collect::>() + .into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into() + } + ArrayExpressionInner::Repeat(box e, box count) => { + let count = match count.into_inner() { + UExpressionInner::Value(from) => from as usize, + _ => unreachable!(), + }; + + let e = remove_spreads(e); + + ArrayExpressionInner::Value( + vec![TypedExpressionOrSpread::Expression(e); count].into(), + ) + .annotate(*array_ty.ty, array_ty.size) + .into() + } + _ => unreachable!(), + } + } + e => e, + } +} + +impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { + type Error = Error; + + fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result, Error> { + let main = p.main.clone(); + + Ok(TypedProgram { + modules: p + .modules + .into_iter() + .map(|(module_id, module)| { + if module_id == main { + self.fold_module(module).map(|m| (module_id, m)) + } else { + Ok((module_id, module)) + } + }) + .collect::>()?, + main: p.main, + }) + } + + fn fold_module(&mut self, m: TypedModule<'ast, T>) -> Result, Error> { + Ok(TypedModule { + functions: m + .functions + .into_iter() + .map(|(key, fun)| { + if key.id == "main" { + self.fold_function_symbol(fun).map(|f| (key, f)) + } else { + Ok((key, fun)) + } + }) + .collect::>()?, + }) + } + + fn fold_function( + &mut self, + f: TypedFunction<'ast, T>, + ) -> Result, Error> { fold_function(self, f) } - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - let res = match s { - TypedStatement::Declaration(v) => vec![TypedStatement::Declaration(v)], - TypedStatement::Return(expressions) => vec![TypedStatement::Return( - expressions - .into_iter() - .map(|e| self.fold_expression(e)) - .collect(), - )], + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Error> { + match s { // propagation to the defined variable if rhs is a constant TypedStatement::Definition(assignee, expr) => { - let expr = self.fold_expression(expr); - let assignee = self.fold_assignee(assignee); + let expr = self.fold_expression(expr)?; + let assignee = self.fold_assignee(assignee)?; + + if let (Ok(a), Ok(e)) = ( + ConcreteType::try_from(assignee.get_type()), + ConcreteType::try_from(expr.get_type()), + ) { + if a != e { + return Err(Error::Type(format!( + "Cannot assign {} of type {} to {} of type {}", + expr, e, assignee, a + ))); + } + }; if is_constant(&expr) { - let verbose = self.verbose; - match assignee { - TypedAssignee::Identifier(var) => match verbose { - true => { - assert!(self - .constants - .insert(var.id.clone(), expr.clone()) - .is_none()); - vec![TypedStatement::Definition( - TypedAssignee::Identifier(var), - expr, - )] - } - false => { - assert!(self.constants.insert(var.id, expr).is_none()); + TypedAssignee::Identifier(var) => { + let expr = remove_spreads(expr); - vec![] - } - }, + assert!(self.constants.insert(var.id, expr).is_none()); + + Ok(vec![]) + } assignee => match self.try_get_constant_mut(&assignee) { - Ok((_, c)) => match verbose { - true => { - *c = expr.clone(); - vec![TypedStatement::Definition(assignee, expr)] - } - false => { - *c = expr; - vec![] - } - }, + Ok((_, c)) => { + *c = remove_spreads(expr); + Ok(vec![]) + } Err(v) => match self.constants.remove(&v.id) { // invalidate the cache for this identifier, and define the latest // version of the constant in the program, if any - Some(c) => vec![ + Some(c) => Ok(vec![ TypedStatement::Definition(v.clone().into(), c), TypedStatement::Definition(assignee, expr), - ], - None => vec![TypedStatement::Definition(assignee, expr)], + ]), + None => Ok(vec![TypedStatement::Definition(assignee, expr)]), }, }, } @@ -198,68 +315,66 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { .unwrap_or_else(|v| v); match self.constants.remove(&v.id) { - Some(c) => vec![ + Some(c) => Ok(vec![ TypedStatement::Definition(v.clone().into(), c), TypedStatement::Definition(assignee, expr), - ], - None => vec![TypedStatement::Definition(assignee, expr)], + ]), + None => Ok(vec![TypedStatement::Definition(assignee, expr)]), } } } - // propagate the boolean - TypedStatement::Assertion(e) => { - // could stop execution here if condition is known to fail - vec![TypedStatement::Assertion(self.fold_boolean_expression(e))] - } - // only loops with variable bounds are expected here - // we stop propagation here as constants maybe be modified inside the loop body - // which we do not visit + // we do not visit the for-loop statements TypedStatement::For(v, from, to, statements) => { - let from = self.fold_field_expression(from); - let to = self.fold_field_expression(to); + let from = self.fold_uint_expression(from)?; + let to = self.fold_uint_expression(to)?; - // invalidate the constants map as any constant could be modified inside the loop body, which we don't visit - self.constants.clear(); - - vec![TypedStatement::For(v, from, to, statements)] + Ok(vec![TypedStatement::For(v, from, to, statements)]) } TypedStatement::MultipleDefinition(assignees, expression_list) => { let assignees: Vec> = assignees .into_iter() .map(|a| self.fold_assignee(a)) - .collect(); - let expression_list = self.fold_expression_list(expression_list); + .collect::>()?; + let expression_list = self.fold_expression_list(expression_list)?; - match expression_list { - TypedExpressionList::FunctionCall(key, arguments, types) => { + let statements = match expression_list { + TypedExpressionList::EmbedCall(embed, generics, arguments, types) => { let arguments: Vec<_> = arguments .into_iter() .map(|a| self.fold_expression(a)) - .collect(); + .collect::>()?; + + let types = types + .into_iter() + .map(|t| self.fold_type(t)) + .collect::>()?; fn process_u_from_bits<'ast, T: Field>( variables: Vec>, - arguments: Vec>, + mut arguments: Vec>, bitwidth: UBitwidth, ) -> TypedExpression<'ast, T> { assert_eq!(variables.len(), 1); assert_eq!(arguments.len(), 1); - use std::convert::TryInto; + let argument = arguments.pop().unwrap(); - match ArrayExpression::try_from(arguments[0].clone()) + let argument = remove_spreads(argument); + + match ArrayExpression::try_from(argument) .unwrap() .into_inner() { - ArrayExpressionInner::Value(v) => { - assert_eq!(v.len(), bitwidth.to_usize()); + ArrayExpressionInner::Value(v) => UExpressionInner::Value( v.into_iter() .map(|v| match v { - TypedExpression::Boolean( - BooleanExpression::Value(v), + TypedExpressionOrSpread::Expression( + TypedExpression::Boolean( + BooleanExpression::Value(v), + ), ) => v, - _ => unreachable!("should be a boolean value"), + _ => unreachable!("Should be a constant boolean expression. Spreads are not expected here, as in their presence the argument isn't constant"), }) .enumerate() .fold(0, |acc, (i, v)| { @@ -275,9 +390,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { }), ) .annotate(bitwidth) - .into() - } - _ => unreachable!("should be an array value"), + .into(), + v => unreachable!("should be an array value, found {}", v), } } @@ -299,7 +413,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { for i in (0..bitwidth as u32).rev() { if 2u128.pow(i) <= num { - num = num - 2u128.pow(i); + num -= 2u128.pow(i); res.push(true); } else { res.push(false); @@ -310,9 +424,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { ArrayExpressionInner::Value( res.into_iter() .map(|v| BooleanExpression::Value(v).into()) - .collect(), + .collect::>() + .into(), ) - .annotate(Type::Boolean, bitwidth.to_usize()) + .annotate(Type::Boolean, bitwidth.to_usize() as u32) .into() } _ => unreachable!("should be a uint value"), @@ -321,40 +436,44 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { match arguments.iter().all(|a| is_constant(a)) { true => { - let r: Option> = match key.id { - "_U32_FROM_BITS" => Some(process_u_from_bits( + let r: Option> = match embed { + FlatEmbed::U32ToField => None, // todo + FlatEmbed::U32FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), UBitwidth::B32, )), - "_U16_FROM_BITS" => Some(process_u_from_bits( + FlatEmbed::U16FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), UBitwidth::B16, )), - "_U8_FROM_BITS" => Some(process_u_from_bits( + FlatEmbed::U8FromBits => Some(process_u_from_bits( assignees.clone(), arguments.clone(), UBitwidth::B8, )), - "_U32_TO_BITS" => Some(process_u_to_bits( + FlatEmbed::U32ToBits => Some(process_u_to_bits( assignees.clone(), arguments.clone(), UBitwidth::B32, )), - "_U16_TO_BITS" => Some(process_u_to_bits( + FlatEmbed::U16ToBits => Some(process_u_to_bits( assignees.clone(), arguments.clone(), UBitwidth::B16, )), - "_U8_TO_BITS" => Some(process_u_to_bits( + FlatEmbed::U8ToBits => Some(process_u_to_bits( assignees.clone(), arguments.clone(), UBitwidth::B8, )), - "_UNPACK" => { + FlatEmbed::Unpack => { assert_eq!(assignees.len(), 1); assert_eq!(arguments.len(), 1); + assert_eq!(generics.len(), 1); + + let bit_width = generics[0]; match FieldElementExpression::try_from(arguments[0].clone()) .unwrap() @@ -363,7 +482,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { let mut num = num; let mut res = vec![]; - for i in (0..T::get_required_bits()).rev() { + for i in (0..bit_width as usize).rev() { if T::from(2).pow(i) <= num { num = num - T::from(2).pow(i); res.push(true); @@ -379,56 +498,35 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { .map(|v| { BooleanExpression::Value(v).into() }) - .collect(), + .collect::>() + .into(), ) - .annotate(Type::Boolean, T::get_required_bits()) + .annotate(Type::Boolean, bit_width) .into(), ) } _ => unreachable!("should be a field value"), } } - "_SHA256_ROUND" => None, - _ => None, + FlatEmbed::Sha256Round => None, }; match r { // if the function call returns a constant Some(expr) => { - let verbose = self.verbose; - let mut assignees = assignees; match assignees.pop().unwrap() { - TypedAssignee::Identifier(var) => match verbose { - true => { - self.constants - .insert(var.id.clone(), expr.clone()); - vec![TypedStatement::Definition( - TypedAssignee::Identifier(var), - expr, - )] - } - false => { - self.constants.insert(var.id, expr); - - vec![] - } - }, + TypedAssignee::Identifier(var) => { + self.constants.insert(var.id, expr); + vec![] + } assignee => { match self.try_get_constant_mut(&assignee) { - Ok((_, c)) => match verbose { - true => { - *c = expr.clone(); - vec![TypedStatement::Definition( - assignee, expr, - )] - } - false => { - *c = expr; - vec![] - } - }, + Ok((_, c)) => { + *c = expr; + vec![] + } Err(v) => match self.constants.remove(&v.id) { Some(c) => vec![ TypedStatement::Definition( @@ -464,15 +562,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { TypedStatement::Definition(v.clone().into(), c), TypedStatement::MultipleDefinition( vec![assignee], - TypedExpressionList::FunctionCall( - key, arguments, types, + TypedExpressionList::EmbedCall( + embed, generics, arguments, types, ), ), ], None => vec![TypedStatement::MultipleDefinition( vec![assignee], - TypedExpressionList::FunctionCall( - key, arguments, types, + TypedExpressionList::EmbedCall( + embed, generics, arguments, types, ), )], } @@ -483,49 +581,84 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { // if the function arguments are not constant, invalidate the cache // for the return assignees - let invalidations = assignees - .iter() - .flat_map(|assignee| { - let v = self - .try_get_constant_mut(&assignee) - .map(|(v, _)| v) - .unwrap_or_else(|v| v); - match self.constants.remove(&v.id) { - Some(c) => vec![TypedStatement::Definition( - v.clone().into(), - c, - )], - None => vec![], - } - }) - .collect::>(); + let def = TypedStatement::MultipleDefinition( + assignees.clone(), + TypedExpressionList::EmbedCall( + embed, generics, arguments, types, + ), + ); - let l = TypedExpressionList::FunctionCall(key, arguments, types); - invalidations - .into_iter() - .chain(std::iter::once(TypedStatement::MultipleDefinition( - assignees, l, - ))) - .collect() + let invalidations = assignees.iter().flat_map(|assignee| { + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + match self.constants.remove(&v.id) { + Some(c) => { + vec![TypedStatement::Definition(v.clone().into(), c)] + } + None => vec![], + } + }); + + invalidations.chain(std::iter::once(def)).collect() } } } - } + TypedExpressionList::FunctionCall(key, generics, arguments, types) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| self.fold_uint_expression(g)).transpose()) + .collect::>()?; + + let arguments: Vec<_> = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::>()?; + + let types = types + .into_iter() + .map(|t| self.fold_type(t)) + .collect::>()?; + + // invalidate the cache for the return assignees as this call mutates them + + let def = TypedStatement::MultipleDefinition( + assignees.clone(), + TypedExpressionList::FunctionCall(key, generics, arguments, types), + ); + + let invalidations = assignees.iter().flat_map(|assignee| { + let v = self + .try_get_constant_mut(&assignee) + .map(|(v, _)| v) + .unwrap_or_else(|v| v); + match self.constants.remove(&v.id) { + Some(c) => { + vec![TypedStatement::Definition(v.clone().into(), c)] + } + None => vec![], + } + }); + + invalidations.chain(std::iter::once(def)).collect() + } + }; + + Ok(statements) } - }; - - // In verbose mode, we always return at least a statement - assert!(res.len() > 0 || !self.verbose); - - res + s @ TypedStatement::PushCallLog(..) => Ok(vec![s]), + s @ TypedStatement::PopCallLog => Ok(vec![s]), + s => fold_statement(self, s), + } } fn fold_uint_expression_inner( &mut self, bitwidth: UBitwidth, e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { - match e { + ) -> Result, Error> { + Ok(match e { UExpressionInner::Identifier(id) => match self.constants.get(&id) { Some(e) => match e { TypedExpression::Uint(e) => e.as_inner().clone(), @@ -534,11 +667,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { None => UExpressionInner::Identifier(id), }, UExpressionInner::Add(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - use std::convert::TryInto; UExpressionInner::Value( (v1 + v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), ) @@ -555,16 +687,15 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::Sub(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - use std::convert::TryInto; UExpressionInner::Value( (v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), ) } - (e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v { + (e, UExpressionInner::Value(v)) => match v { 0 => e, _ => UExpressionInner::Sub( box e.annotate(bitwidth), @@ -575,12 +706,31 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { UExpressionInner::Sub(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) } }, - UExpressionInner::Mult(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + UExpressionInner::FloorSub(box e1, box e2) => match ( + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), + ) { + (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { + UExpressionInner::Value( + v1.saturating_sub(v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + ) + } + (e, UExpressionInner::Value(v)) => match v { + 0 => e, + _ => UExpressionInner::FloorSub( + box e.annotate(bitwidth), + box UExpressionInner::Value(v).annotate(bitwidth), + ), + }, + (e1, e2) => { + UExpressionInner::Sub(box e1.annotate(bitwidth), box e2.annotate(bitwidth)) + } + }, + UExpressionInner::Mult(box e1, box e2) => match ( + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - use std::convert::TryInto; UExpressionInner::Value( (v1 * v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), ) @@ -598,11 +748,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::Div(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - use std::convert::TryInto; UExpressionInner::Value( (v1 / v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), ) @@ -619,11 +768,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::Rem(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { - use std::convert::TryInto; UExpressionInner::Value( (v1 % v2) % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), ) @@ -640,8 +788,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::RightShift(box e, box by) => { - let e = self.fold_uint_expression(e); - let by = self.fold_uint_expression(by); + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; match (e.into_inner(), by.into_inner()) { (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { UExpressionInner::Value(v >> by) @@ -653,8 +801,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } UExpressionInner::LeftShift(box e, box by) => { - let e = self.fold_uint_expression(e); - let by = self.fold_uint_expression(by); + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; match (e.into_inner(), by.into_inner()) { (UExpressionInner::Value(v), UExpressionInner::Value(by)) => { UExpressionInner::Value((v << by) & (2_u128.pow(bitwidth as u32) - 1)) @@ -666,8 +814,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } UExpressionInner::Xor(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { UExpressionInner::Value(v1 ^ v2) @@ -683,8 +831,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::And(box e1, box e2) => match ( - self.fold_uint_expression(e1).into_inner(), - self.fold_uint_expression(e2).into_inner(), + self.fold_uint_expression(e1)?.into_inner(), + self.fold_uint_expression(e2)?.into_inner(), ) { (UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => { UExpressionInner::Value(v1 & v2) @@ -697,16 +845,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } }, UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_uint_expression(consequence); - let alternative = self.fold_uint_expression(alternative); - match self.fold_boolean_expression(condition) { + let consequence = self.fold_uint_expression(consequence)?; + let alternative = self.fold_uint_expression(alternative)?; + match self.fold_boolean_expression(condition)? { BooleanExpression::Value(true) => consequence.into_inner(), BooleanExpression::Value(false) => alternative.into_inner(), c => UExpressionInner::IfElse(box c, box consequence, box alternative), } } UExpressionInner::Not(box e) => { - let e = self.fold_uint_expression(e).into_inner(); + let e = self.fold_uint_expression(e)?.into_inner(); match e { UExpressionInner::Value(v) => { UExpressionInner::Value((!v) & (2_u128.pow(bitwidth as u32) - 1)) @@ -715,88 +863,95 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } UExpressionInner::Neg(box e) => { - let e = self.fold_uint_expression(e).into_inner(); + let e = self.fold_uint_expression(e)?.into_inner(); match e { - UExpressionInner::Value(v) => { - use std::convert::TryInto; - UExpressionInner::Value( - (0u128.wrapping_sub(v)) - % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), - ) - } + UExpressionInner::Value(v) => UExpressionInner::Value( + (0u128.wrapping_sub(v)) + % 2_u128.pow(bitwidth.to_usize().try_into().unwrap()), + ), e => UExpressionInner::Neg(box e.annotate(bitwidth)), } } UExpressionInner::Pos(box e) => { - let e = self.fold_uint_expression(e).into_inner(); + let e = self.fold_uint_expression(e)?.into_inner(); match e { UExpressionInner::Value(v) => UExpressionInner::Value(v), e => UExpressionInner::Pos(box e.annotate(bitwidth)), } } UExpressionInner::Select(box array, box index) => { - let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - UExpression::try_from(v[n_as_usize].clone()) - .unwrap() - .into_inner() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); - } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&id) { - Some(a) => match a { - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => UExpression::try_from( - v[n.to_dec_string().parse::().unwrap()].clone(), + match size.into_inner() { + UExpressionInner::Value(size) => { + match (array.into_inner(), index.into_inner()) { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + UExpression::try_from( + v.expression_at::>(n as usize) + .unwrap() + .clone(), ) .unwrap() - .into_inner(), - _ => unreachable!(), - }, - _ => unreachable!(""), - }, - None => UExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), + .into_inner() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } + } + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&id) { + Some(a) => match a { + TypedExpression::Array(a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + UExpression::try_from( + TypedExpression::try_from( + v.0[n as usize].clone(), + ) + .unwrap(), + ) + .unwrap() + .into_inner() + } + _ => unreachable!(), + }, + _ => unreachable!(""), + }, + None => UExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } + } + (a, i) => UExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), ), } } - (a, i) => UExpressionInner::Select(box a.annotate(inner_type, size), box i), + _ => fold_uint_expression_inner( + self, + bitwidth, + UExpressionInner::Select(box array, box index), + )?, } } - UExpressionInner::FunctionCall(key, arguments) => { - assert!( - self.verbose, - "function calls should only exist out of multidef in verbose mode" - ); - fold_uint_expression_inner( - self, - bitwidth, - UExpressionInner::FunctionCall(key, arguments), - ) - } - e => fold_uint_expression_inner(self, bitwidth, e), - } + e => fold_uint_expression_inner(self, bitwidth, e)?, + }) } fn fold_field_expression( &mut self, e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { - match e { + ) -> Result, Error> { + Ok(match e { FieldElementExpression::Identifier(id) => match self.constants.get(&id) { Some(e) => match e { TypedExpression::FieldElement(e) => e.clone(), @@ -807,8 +962,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { None => FieldElementExpression::Identifier(id), }, FieldElementExpression::Add(box e1, box e2) => match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 + n2) @@ -816,8 +971,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { (e1, e2) => FieldElementExpression::Add(box e1, box e2), }, FieldElementExpression::Sub(box e1, box e2) => match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 - n2) @@ -825,8 +980,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { (e1, e2) => FieldElementExpression::Sub(box e1, box e2), }, FieldElementExpression::Mult(box e1, box e2) => match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 * n2) @@ -834,96 +989,116 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { (e1, e2) => FieldElementExpression::Mult(box e1, box e2), }, FieldElementExpression::Div(box e1, box e2) => match ( - self.fold_field_expression(e1), - self.fold_field_expression(e2), + self.fold_field_expression(e1)?, + self.fold_field_expression(e2)?, ) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { FieldElementExpression::Number(n1 / n2) } (e1, e2) => FieldElementExpression::Div(box e1, box e2), }, - FieldElementExpression::Neg(box e) => match self.fold_field_expression(e) { + FieldElementExpression::Neg(box e) => match self.fold_field_expression(e)? { FieldElementExpression::Number(n) => FieldElementExpression::Number(T::zero() - n), e => FieldElementExpression::Neg(box e), }, - FieldElementExpression::Pos(box e) => match self.fold_field_expression(e) { + FieldElementExpression::Pos(box e) => match self.fold_field_expression(e)? { FieldElementExpression::Number(n) => FieldElementExpression::Number(n), e => FieldElementExpression::Pos(box e), }, FieldElementExpression::Pow(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); - match (e1, e2) { - (_, FieldElementExpression::Number(ref n2)) if *n2 == T::from(0) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; + match (e1, e2.into_inner()) { + (_, UExpressionInner::Value(ref n2)) if *n2 == 0 => { FieldElementExpression::Number(T::from(1)) } - (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { - FieldElementExpression::Number(n1.pow(n2)) - } - (e1, FieldElementExpression::Number(n2)) => { - FieldElementExpression::Pow(box e1, box FieldElementExpression::Number(n2)) + (FieldElementExpression::Number(n1), UExpressionInner::Value(n2)) => { + FieldElementExpression::Number(n1.pow(n2 as usize)) } + (e1, UExpressionInner::Value(n2)) => FieldElementExpression::Pow( + box e1, + box UExpressionInner::Value(n2).annotate(UBitwidth::B32), + ), (_, e2) => unreachable!(format!( "non-constant exponent {} detected during static analysis", - e2 + e2.annotate(UBitwidth::B32) )), } } FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_field_expression(consequence); - let alternative = self.fold_field_expression(alternative); - match self.fold_boolean_expression(condition) { + let consequence = self.fold_field_expression(consequence)?; + let alternative = self.fold_field_expression(alternative)?; + match self.fold_boolean_expression(condition)? { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, c => FieldElementExpression::IfElse(box c, box consequence, box alternative), } } FieldElementExpression::Select(box array, box index) => { - let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); - } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&id) { - Some(a) => match a { - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => { - FieldElementExpression::try_from( - v[n.to_dec_string().parse::().unwrap()].clone(), + match size.into_inner() { + UExpressionInner::Value(size) => { + match (array.into_inner(), index.into_inner()) { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + FieldElementExpression::try_from( + v.expression_at::>( + n as usize, ) .unwrap() - } - _ => unreachable!(), - }, - _ => unreachable!(""), - }, - None => FieldElementExpression::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), + .clone(), + ) + .unwrap() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } + } + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&id) { + Some(a) => match a { + TypedExpression::Array(a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + FieldElementExpression::try_from( + TypedExpression::try_from( + v.0[n as usize].clone(), + ) + .unwrap(), + ) + .unwrap() + } + _ => unreachable!(), + }, + _ => unreachable!(""), + }, + None => FieldElementExpression::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } + } + (a, i) => FieldElementExpression::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), ), } } - (a, i) => { - FieldElementExpression::Select(box a.annotate(inner_type, size), box i) - } + _ => fold_field_expression( + self, + FieldElementExpression::Select(box array, box index), + )?, } } FieldElementExpression::Member(box s, m) => { - let s = self.fold_struct_expression(s); + let s = self.fold_struct_expression(s)?; let members = match s.get_type() { Type::Struct(members) => members, @@ -946,24 +1121,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { inner => FieldElementExpression::Member(box inner.annotate(members), m), } } - FieldElementExpression::FunctionCall(key, inputs) => { - assert!( - self.verbose, - "function calls should only exist out of multidef in verbose mode" - ); - fold_field_expression(self, FieldElementExpression::FunctionCall(key, inputs)) - } - e => fold_field_expression(self, e), - } + e => fold_field_expression(self, e)?, + }) } fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { - match e { + ) -> Result, Error> { + Ok(match e { ArrayExpressionInner::Identifier(id) => match self.constants.get(&id) { Some(e) => match e { TypedExpression::Array(e) => e.as_inner().clone(), @@ -972,71 +1139,92 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { None => ArrayExpressionInner::Identifier(id), }, ArrayExpressionInner::Select(box array, box index) => { - let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - ArrayExpression::try_from(v[n_as_usize].clone()) + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + ArrayExpression::try_from( + v.expression_at::>(n as usize) + .unwrap() + .clone(), + ) .unwrap() .into_inner() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&id) { - Some(a) => match a { - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => ArrayExpression::try_from( - v[n.to_dec_string().parse::().unwrap()].clone(), - ) - .unwrap() - .into_inner(), - _ => unreachable!(), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&id) { + Some(a) => match a { + TypedExpression::Array(a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + ArrayExpression::try_from( + v.expression_at::>( + n as usize, + ) + .unwrap() + .clone(), + ) + .unwrap() + .into_inner() + } + _ => unreachable!(), + }, + _ => unreachable!(""), }, - _ => unreachable!(""), - }, - None => ArrayExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + None => ArrayExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } } - } - (a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i), + (a, i) => ArrayExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + _ => fold_array_expression_inner( + self, + ty, + ArrayExpressionInner::Select(box array, box index), + )?, } } ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_array_expression(consequence); - let alternative = self.fold_array_expression(alternative); - match self.fold_boolean_expression(condition) { + let consequence = self.fold_array_expression(consequence)?; + let alternative = self.fold_array_expression(alternative)?; + match self.fold_boolean_expression(condition)? { BooleanExpression::Value(true) => consequence.into_inner(), BooleanExpression::Value(false) => alternative.into_inner(), c => ArrayExpressionInner::IfElse(box c, box consequence, box alternative), } } - ArrayExpressionInner::Member(box s, m) => { - let s = self.fold_struct_expression(s); + ArrayExpressionInner::Member(box struc, id) => { + let struc = self.fold_struct_expression(struc)?; - let members = match s.get_type() { + let members = match struc.get_type() { Type::Struct(members) => members, _ => unreachable!("should be a struct"), }; - match s.into_inner() { + match struc.into_inner() { StructExpressionInner::Value(v) => { match members .iter() .zip(v) - .find(|(member, _)| member.id == m) + .find(|(member, _)| member.id == id) .unwrap() .1 { @@ -1044,31 +1232,19 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { _ => unreachable!("should be an array"), } } - inner => ArrayExpressionInner::Member(box inner.annotate(members), m), + inner => ArrayExpressionInner::Member(box inner.annotate(members), id), } } - ArrayExpressionInner::FunctionCall(key, inputs) => { - assert!( - self.verbose, - "function calls should only exist out of multidef in verbose mode" - ); - fold_array_expression_inner( - self, - ty, - size, - ArrayExpressionInner::FunctionCall(key, inputs), - ) - } - e => fold_array_expression_inner(self, ty, size, e), - } + e => fold_array_expression_inner(self, ty, e)?, + }) } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { - match e { + ) -> Result, Error> { + Ok(match e { StructExpressionInner::Identifier(id) => match self.constants.get(&id) { Some(e) => match e { TypedExpression::Struct(e) => e.as_inner().clone(), @@ -1077,61 +1253,80 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { None => StructExpressionInner::Identifier(id), }, StructExpressionInner::Select(box array, box index) => { - let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - StructExpression::try_from(v[n_as_usize].clone()) + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + StructExpression::try_from( + v.expression_at::>(n as usize) + .unwrap() + .clone(), + ) .unwrap() .into_inner() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&id) { - Some(a) => match a { - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => StructExpression::try_from( - v[n.to_dec_string().parse::().unwrap()].clone(), - ) - .unwrap() - .into_inner(), - _ => unreachable!(), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&id) { + Some(a) => match a { + TypedExpression::Array(a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + StructExpression::try_from( + v.expression_at::>( + n as usize, + ) + .unwrap() + .clone(), + ) + .unwrap() + .into_inner() + } + _ => unreachable!(), + }, + _ => unreachable!(""), }, - _ => unreachable!(""), - }, - None => StructExpressionInner::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + None => StructExpressionInner::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } } - } - (a, i) => { - StructExpressionInner::Select(box a.annotate(inner_type, size), box i) - } + (a, i) => StructExpressionInner::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + _ => fold_struct_expression_inner( + self, + ty, + StructExpressionInner::Select(box array, box index), + )?, } } StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_struct_expression(consequence); - let alternative = self.fold_struct_expression(alternative); - match self.fold_boolean_expression(condition) { + let consequence = self.fold_struct_expression(consequence)?; + let alternative = self.fold_struct_expression(alternative)?; + match self.fold_boolean_expression(condition)? { BooleanExpression::Value(true) => consequence.into_inner(), BooleanExpression::Value(false) => alternative.into_inner(), c => StructExpressionInner::IfElse(box c, box consequence, box alternative), } } StructExpressionInner::Member(box s, m) => { - let s = self.fold_struct_expression(s); + let s = self.fold_struct_expression(s)?; let members = match s.get_type() { Type::Struct(members) => members, @@ -1154,31 +1349,20 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { inner => StructExpressionInner::Member(box inner.annotate(members), m), } } - StructExpressionInner::FunctionCall(key, inputs) => { - assert!( - self.verbose, - "function calls should only exist out of multidef in verbose mode" - ); - fold_struct_expression_inner( - self, - ty, - StructExpressionInner::FunctionCall(key, inputs), - ) - } - e => fold_struct_expression_inner(self, ty, e), - } + e => fold_struct_expression_inner(self, ty, e)?, + }) } fn fold_boolean_expression( &mut self, e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { + ) -> Result, Error> { // Note: we only propagate when we see constants, as comparing of arbitrary expressions would lead to // a lot of false negatives due to expressions not being in a canonical form // For example, `2 * a` is equivalent to `a + a`, but our notion of equality would not detect that here // These kind of reduction rules are easier to apply later in the process, when we have canonical representations // of expressions, ie `a + a` would always be written `2 * a` - match e { + Ok(match e { BooleanExpression::Identifier(id) => match self.constants.get(&id) { Some(e) => match e { TypedExpression::Boolean(e) => e.clone(), @@ -1187,8 +1371,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { None => BooleanExpression::Identifier(id), }, BooleanExpression::FieldEq(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; match (e1, e2) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { @@ -1198,8 +1382,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } BooleanExpression::UintEq(box e1, box e2) => { - let e1 = self.fold_uint_expression(e1); - let e2 = self.fold_uint_expression(e2); + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; match (e1.as_inner(), e2.as_inner()) { (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { @@ -1209,8 +1393,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } BooleanExpression::BoolEq(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1); - let e2 = self.fold_boolean_expression(e2); + let e1 = self.fold_boolean_expression(e1)?; + let e2 = self.fold_boolean_expression(e2)?; match (e1, e2) { (BooleanExpression::Value(n1), BooleanExpression::Value(n2)) => { @@ -1219,53 +1403,115 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { (e1, e2) => BooleanExpression::BoolEq(box e1, box e2), } } - BooleanExpression::Lt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); + BooleanExpression::ArrayEq(box e1, box e2) => { + let e1 = self.fold_array_expression(e1)?; + let e2 = self.fold_array_expression(e2)?; + + if let (Ok(t1), Ok(t2)) = ( + ConcreteType::try_from(e1.get_type()), + ConcreteType::try_from(e2.get_type()), + ) { + if t1 != t2 { + return Err(Error::Type(format!( + "Cannot compare {} of type {} to {} of type {}", + e1, t1, e2, t2 + ))); + } + }; + + BooleanExpression::ArrayEq(box e1, box e2) + } + BooleanExpression::FieldLt(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; match (e1, e2) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 < n2) } - (e1, e2) => BooleanExpression::Lt(box e1, box e2), + (e1, e2) => BooleanExpression::FieldLt(box e1, box e2), } } - BooleanExpression::Le(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); + BooleanExpression::FieldLe(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; match (e1, e2) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 <= n2) } - (e1, e2) => BooleanExpression::Le(box e1, box e2), + (e1, e2) => BooleanExpression::FieldLe(box e1, box e2), } } - BooleanExpression::Gt(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); + BooleanExpression::FieldGt(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; match (e1, e2) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 > n2) } - (e1, e2) => BooleanExpression::Gt(box e1, box e2), + (e1, e2) => BooleanExpression::FieldGt(box e1, box e2), } } - BooleanExpression::Ge(box e1, box e2) => { - let e1 = self.fold_field_expression(e1); - let e2 = self.fold_field_expression(e2); + BooleanExpression::FieldGe(box e1, box e2) => { + let e1 = self.fold_field_expression(e1)?; + let e2 = self.fold_field_expression(e2)?; match (e1, e2) { (FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => { BooleanExpression::Value(n1 >= n2) } - (e1, e2) => BooleanExpression::Ge(box e1, box e2), + (e1, e2) => BooleanExpression::FieldGe(box e1, box e2), + } + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; + + match (e1.as_inner(), e2.as_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + BooleanExpression::Value(n1 < n2) + } + _ => BooleanExpression::UintLt(box e1, box e2), + } + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; + + match (e1.as_inner(), e2.as_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + BooleanExpression::Value(n1 <= n2) + } + _ => BooleanExpression::UintLe(box e1, box e2), + } + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; + + match (e1.as_inner(), e2.as_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + BooleanExpression::Value(n1 > n2) + } + _ => BooleanExpression::UintGt(box e1, box e2), + } + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = self.fold_uint_expression(e1)?; + let e2 = self.fold_uint_expression(e2)?; + + match (e1.as_inner(), e2.as_inner()) { + (UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => { + BooleanExpression::Value(n1 >= n2) + } + _ => BooleanExpression::UintGe(box e1, box e2), } } BooleanExpression::Or(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1); - let e2 = self.fold_boolean_expression(e2); + let e1 = self.fold_boolean_expression(e1)?; + let e2 = self.fold_boolean_expression(e2)?; match (e1, e2) { // reduction of constants @@ -1284,8 +1530,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } BooleanExpression::And(box e1, box e2) => { - let e1 = self.fold_boolean_expression(e1); - let e2 = self.fold_boolean_expression(e2); + let e1 = self.fold_boolean_expression(e1)?; + let e2 = self.fold_boolean_expression(e2)?; match (e1, e2) { // reduction of constants @@ -1302,63 +1548,81 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { } } BooleanExpression::Not(box e) => { - let e = self.fold_boolean_expression(e); + let e = self.fold_boolean_expression(e)?; match e { BooleanExpression::Value(v) => BooleanExpression::Value(!v), e => BooleanExpression::Not(box e), } } BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_boolean_expression(consequence); - let alternative = self.fold_boolean_expression(alternative); - match self.fold_boolean_expression(condition) { + let consequence = self.fold_boolean_expression(consequence)?; + let alternative = self.fold_boolean_expression(alternative)?; + match self.fold_boolean_expression(condition)? { BooleanExpression::Value(true) => consequence, BooleanExpression::Value(false) => alternative, c => BooleanExpression::IfElse(box c, box consequence, box alternative), } } BooleanExpression::Select(box array, box index) => { - let array = self.fold_array_expression(array); - let index = self.fold_field_expression(index); + let array = self.fold_array_expression(array)?; + let index = self.fold_uint_expression(index)?; let inner_type = array.inner_type().clone(); let size = array.size(); - match (array.into_inner(), index) { - (ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => { - let n_as_usize = n.to_dec_string().parse::().unwrap(); - if n_as_usize < size { - BooleanExpression::try_from(v[n_as_usize].clone()).unwrap() - } else { - unreachable!( - "out of bounds index ({} >= {}) found during static analysis", - n_as_usize, size - ); + match size.into_inner() { + UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) + { + (ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => { + if n < size { + BooleanExpression::try_from( + v.expression_at::>(n as usize) + .unwrap() + .clone(), + ) + .unwrap() + } else { + unreachable!( + "out of bounds index ({} >= {}) found during static analysis", + n, size + ); + } } - } - (ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => { - match self.constants.get(&id) { - Some(a) => match a { - TypedExpression::Array(a) => match a.as_inner() { - ArrayExpressionInner::Value(v) => BooleanExpression::try_from( - v[n.to_dec_string().parse::().unwrap()].clone(), - ) - .unwrap(), - _ => unreachable!(), + (ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => { + match self.constants.get(&id) { + Some(a) => match a { + TypedExpression::Array(a) => match a.as_inner() { + ArrayExpressionInner::Value(v) => { + BooleanExpression::try_from( + TypedExpression::try_from(v.0[n as usize].clone()) + .unwrap(), + ) + .unwrap() + } + _ => unreachable!(), + }, + _ => unreachable!(""), }, - _ => unreachable!(""), - }, - None => BooleanExpression::Select( - box ArrayExpressionInner::Identifier(id).annotate(inner_type, size), - box FieldElementExpression::Number(n), - ), + None => BooleanExpression::Select( + box ArrayExpressionInner::Identifier(id) + .annotate(inner_type, size as u32), + box UExpressionInner::Value(n).annotate(UBitwidth::B32), + ), + } } - } - (a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i), + (a, i) => BooleanExpression::Select( + box a.annotate(inner_type, size as u32), + box i.annotate(UBitwidth::B32), + ), + }, + _ => fold_boolean_expression( + self, + BooleanExpression::Select(box array, box index), + )?, } } BooleanExpression::Member(box s, m) => { - let s = self.fold_struct_expression(s); + let s = self.fold_struct_expression(s)?; let members = match s.get_type() { Type::Struct(members) => members, @@ -1381,15 +1645,8 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> { inner => BooleanExpression::Member(box inner.annotate(members), m), } } - BooleanExpression::FunctionCall(key, inputs) => { - assert!( - self.verbose, - "function calls should only exist out of multidef in verbose mode" - ); - fold_boolean_expression(self, BooleanExpression::FunctionCall(key, inputs)) - } - e => fold_boolean_expression(self, e), - } + e => fold_boolean_expression(self, e)?, + }) } } @@ -1414,8 +1671,8 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(5)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(5))) ); } @@ -1427,8 +1684,8 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(1)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); } @@ -1440,8 +1697,8 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(6)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(6))) ); } @@ -1453,8 +1710,8 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(3)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1462,12 +1719,12 @@ mod tests { fn pow() { let e = FieldElementExpression::Pow( box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + box 3u32.into(), ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(8)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(8))) ); } @@ -1480,8 +1737,8 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(2)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); } @@ -1494,29 +1751,30 @@ mod tests { ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(3)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } #[test] fn select() { let e = FieldElementExpression::Select( - box ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ]) - .annotate(Type::FieldElement, 3), - box FieldElementExpression::Add( - box FieldElementExpression::Number(Bn128Field::from(1)), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), + box ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(3)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 3usize), + box UExpressionInner::Add(box 1u32.into(), box 1u32.into()) + .annotate(UBitwidth::B32), ); assert_eq!( - Propagator::new().fold_field_expression(e), - FieldElementExpression::Number(Bn128Field::from(3)) + Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } } @@ -1537,16 +1795,19 @@ mod tests { BooleanExpression::Not(box BooleanExpression::Identifier("a".into())); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_default.clone()), - e_default + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_default.clone()), + Ok(e_default) ); } @@ -1563,143 +1824,149 @@ mod tests { ); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); } #[test] fn bool_eq() { assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::BoolEq( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::BoolEq( box BooleanExpression::Value(false), box BooleanExpression::Value(false) - ) - ), - BooleanExpression::Value(true) + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::BoolEq( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::BoolEq( box BooleanExpression::Value(true), box BooleanExpression::Value(true) - ) - ), - BooleanExpression::Value(true) + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::BoolEq( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::BoolEq( box BooleanExpression::Value(true), box BooleanExpression::Value(false) - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::BoolEq( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::BoolEq( box BooleanExpression::Value(false), box BooleanExpression::Value(true) - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); } #[test] fn lt() { - let e_true = BooleanExpression::Lt( + let e_true = BooleanExpression::FieldLt( box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(4)), ); - let e_false = BooleanExpression::Lt( + let e_false = BooleanExpression::FieldLt( box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(2)), ); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); } #[test] fn le() { - let e_true = BooleanExpression::Le( + let e_true = BooleanExpression::FieldLe( box FieldElementExpression::Number(Bn128Field::from(2)), box FieldElementExpression::Number(Bn128Field::from(2)), ); - let e_false = BooleanExpression::Le( + let e_false = BooleanExpression::FieldLe( box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(2)), ); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); } #[test] fn gt() { - let e_true = BooleanExpression::Gt( + let e_true = BooleanExpression::FieldGt( box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(4)), ); - let e_false = BooleanExpression::Gt( + let e_false = BooleanExpression::FieldGt( box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(5)), ); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); } #[test] fn ge() { - let e_true = BooleanExpression::Ge( + let e_true = BooleanExpression::FieldGe( box FieldElementExpression::Number(Bn128Field::from(5)), box FieldElementExpression::Number(Bn128Field::from(5)), ); - let e_false = BooleanExpression::Ge( + let e_false = BooleanExpression::FieldGe( box FieldElementExpression::Number(Bn128Field::from(4)), box FieldElementExpression::Number(Bn128Field::from(5)), ); assert_eq!( - Propagator::new().fold_boolean_expression(e_true), - BooleanExpression::Value(true) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_true), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::new().fold_boolean_expression(e_false), - BooleanExpression::Value(false) + Propagator::with_constants(&mut Constants::new()) + .fold_boolean_expression(e_false), + Ok(BooleanExpression::Value(false)) ); } @@ -1708,76 +1975,68 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Identifier(a_bool.clone()) - ) - ), - BooleanExpression::Identifier(a_bool.clone()) + )), + Ok(BooleanExpression::Identifier(a_bool.clone())) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Identifier(a_bool.clone()), box BooleanExpression::Value(true), - ) - ), - BooleanExpression::Identifier(a_bool.clone()) + )), + Ok(BooleanExpression::Identifier(a_bool.clone())) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Identifier(a_bool.clone()) - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Identifier(a_bool.clone()), box BooleanExpression::Value(false), - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - ) - ), - BooleanExpression::Value(true) + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression( - BooleanExpression::And( + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - ) - ), - BooleanExpression::Value(false) + )), + Ok(BooleanExpression::Value(false)) ); } @@ -1786,60 +2045,68 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Identifier(a_bool.clone()) - )), - BooleanExpression::Value(true) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(true), + box BooleanExpression::Identifier(a_bool.clone()) + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Identifier(a_bool.clone()), - box BooleanExpression::Value(true), - )), - BooleanExpression::Value(true) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Identifier(a_bool.clone()), + box BooleanExpression::Value(true), + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Identifier(a_bool.clone()) - )), - BooleanExpression::Identifier(a_bool.clone()) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(false), + box BooleanExpression::Identifier(a_bool.clone()) + )), + Ok(BooleanExpression::Identifier(a_bool.clone())) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Identifier(a_bool.clone()), - box BooleanExpression::Value(false), - )), - BooleanExpression::Identifier(a_bool.clone()) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Identifier(a_bool.clone()), + box BooleanExpression::Value(false), + )), + Ok(BooleanExpression::Identifier(a_bool.clone())) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(false), - )), - BooleanExpression::Value(true) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(true), + box BooleanExpression::Value(false), + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(true), - )), - BooleanExpression::Value(true) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(false), + box BooleanExpression::Value(true), + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(true), - box BooleanExpression::Value(true), - )), - BooleanExpression::Value(true) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(true), + box BooleanExpression::Value(true), + )), + Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::new().fold_boolean_expression(BooleanExpression::Or( - box BooleanExpression::Value(false), - box BooleanExpression::Value(false), - )), - BooleanExpression::Value(false) + Propagator::::with_constants(&mut Constants::new()) + .fold_boolean_expression(BooleanExpression::Or( + box BooleanExpression::Value(false), + box BooleanExpression::Value(false), + )), + Ok(BooleanExpression::Value(false)) ); } } diff --git a/zokrates_core/src/static_analysis/redefinition.rs b/zokrates_core/src/static_analysis/redefinition.rs index dbcd8602..f44a8450 100644 --- a/zokrates_core/src/static_analysis/redefinition.rs +++ b/zokrates_core/src/static_analysis/redefinition.rs @@ -68,6 +68,6 @@ impl<'ast, T: Field> Folder<'ast, T> for RedefinitionOptimizer<'ast> { } fn fold_name(&mut self, s: Identifier<'ast>) -> Identifier<'ast> { - self.identifiers.get(&s).map(|r| r.clone()).unwrap_or(s) + self.identifiers.get(&s).cloned().unwrap_or(s) } } diff --git a/zokrates_core/src/static_analysis/reducer/inline.rs b/zokrates_core/src/static_analysis/reducer/inline.rs new file mode 100644 index 00000000..01aff4db --- /dev/null +++ b/zokrates_core/src/static_analysis/reducer/inline.rs @@ -0,0 +1,235 @@ +// The inlining phase takes a call site (fun::(args)) and inlines it: + +// Given: +// ``` +// def foo(field a) -> field: +// a = a +// n = n +// return a +// ``` +// +// The call site +// ``` +// foo::<42>(x) +// ``` +// +// Becomes +// ``` +// # Call foo::<42> with a_0 := x +// n_0 = 42 +// a_1 = a_0 +// n_1 = n_0 +// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1 + +// Notes: +// - The body of the function is in SSA form +// - The return value(s) are assigned to internal variables + +use crate::embed::FlatEmbed; +use crate::static_analysis::reducer::Output; +use crate::static_analysis::reducer::ShallowTransformer; +use crate::static_analysis::reducer::Versions; +use crate::typed_absy::types::ConcreteGenericsAssignment; +use crate::typed_absy::CoreIdentifier; +use crate::typed_absy::Identifier; +use crate::typed_absy::TypedAssignee; +use crate::typed_absy::{ + ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Signature, + Type, TypedExpression, TypedFunctionSymbol, TypedProgram, TypedStatement, UExpression, + UExpressionInner, Variable, +}; +use zokrates_field::Field; + +pub enum InlineError<'ast, T> { + Generic(DeclarationFunctionKey<'ast>, ConcreteFunctionKey<'ast>), + Flat( + FlatEmbed, + Vec, + Vec>, + Vec>, + ), + NonConstant( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, + Vec>, + ), +} + +fn get_canonical_function<'ast, T: Field>( + function_key: DeclarationFunctionKey<'ast>, + program: &TypedProgram<'ast, T>, +) -> (DeclarationFunctionKey<'ast>, TypedFunctionSymbol<'ast, T>) { + match program + .modules + .get(&function_key.module) + .unwrap() + .functions + .iter() + .find(|(key, _)| function_key == **key) + .unwrap() + { + (_, TypedFunctionSymbol::There(key)) => get_canonical_function(key.clone(), &program), + (key, s) => (key.clone(), s.clone()), + } +} + +type InlineResult<'ast, T> = Result< + Output<(Vec>, Vec>), Vec>>, + InlineError<'ast, T>, +>; + +pub fn inline_call<'a, 'ast, T: Field>( + k: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_types: Vec>, + program: &TypedProgram<'ast, T>, + versions: &'a mut Versions<'ast>, +) -> InlineResult<'ast, T> { + use std::convert::TryFrom; + + use crate::typed_absy::Typed; + + // we try to get concrete values for explicit generics + let generics_values: Vec> = generics + .iter() + .map(|g| { + g.as_ref() + .map(|g| match g.as_inner() { + UExpressionInner::Value(v) => Ok(*v as u32), + _ => Err(()), + }) + .transpose() + }) + .collect::>() + .map_err(|_| { + InlineError::NonConstant( + k.clone(), + generics.clone(), + arguments.clone(), + output_types.clone(), + ) + })?; + + // we infer a signature based on inputs and outputs + // this is where we could handle explicit annotations + let inferred_signature = Signature::new() + .generics(generics.clone()) + .inputs(arguments.iter().map(|a| a.get_type()).collect()) + .outputs(output_types.clone()); + + // we try to get concrete values for the whole signature. if this fails we should propagate again + let inferred_signature = match ConcreteSignature::try_from(inferred_signature) { + Ok(s) => s, + Err(_) => { + return Err(InlineError::NonConstant( + k, + generics, + arguments, + output_types, + )); + } + }; + + let (decl_key, symbol) = get_canonical_function(k.clone(), program); + + // get an assignment of generics for this call site + let assignment: ConcreteGenericsAssignment<'ast> = k + .signature + .specialize(generics_values, &inferred_signature) + .map_err(|_| { + InlineError::Generic( + k.clone(), + ConcreteFunctionKey { + module: decl_key.module.clone(), + id: decl_key.id, + signature: inferred_signature.clone(), + }, + ) + })?; + + let f = match symbol { + TypedFunctionSymbol::Here(f) => Ok(f), + TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( + e, + e.generics(&assignment), + arguments.clone(), + output_types, + )), + _ => unreachable!(), + }?; + + assert_eq!(f.arguments.len(), arguments.len()); + + let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) { + Output::Complete(v) => (v, None), + Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)), + }; + + let call_log = TypedStatement::PushCallLog(decl_key.clone(), assignment.clone()); + + let input_bindings: Vec> = ssa_f + .arguments + .into_iter() + .zip(inferred_signature.inputs.clone()) + .map(|(p, t)| ConcreteVariable::with_id_and_type(p.id.id, t)) + .zip(arguments.clone()) + .map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a)) + .collect(); + + let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f + .statements + .into_iter() + .partition(|s| !matches!(s, TypedStatement::Return(..))); + + assert_eq!(returns.len(), 1); + + let returns = match returns.pop().unwrap() { + TypedStatement::Return(e) => e, + _ => unreachable!(), + }; + + let res: Vec> = inferred_signature + .outputs + .iter() + .enumerate() + .map(|(i, t)| { + ConcreteVariable::with_id_and_type( + Identifier::from(CoreIdentifier::Call(i)).version( + *versions + .entry(CoreIdentifier::Call(i).clone()) + .and_modify(|e| *e += 1) // if it was already declared, we increment + .or_insert(0), + ), + t.clone(), + ) + }) + .collect(); + + let expressions: Vec> = res + .iter() + .map(|v| TypedExpression::from(Variable::from(v.clone()))) + .collect(); + + assert_eq!(res.len(), returns.len()); + + let output_bindings: Vec> = res + .into_iter() + .zip(returns) + .map(|(v, a)| TypedStatement::Definition(TypedAssignee::Identifier(v.into()), a)) + .collect(); + + let pop_log = TypedStatement::PopCallLog; + + let statements: Vec<_> = std::iter::once(call_log) + .chain(input_bindings) + .chain(statements) + .chain(output_bindings) + .chain(std::iter::once(pop_log)) + .collect(); + + Ok(incomplete_data + .map(|d| Output::Incomplete((statements.clone(), expressions.clone()), d)) + .unwrap_or_else(|| Output::Complete((statements, expressions)))) +} diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs new file mode 100644 index 00000000..7987a1d6 --- /dev/null +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -0,0 +1,1635 @@ +// The reducer reduces the program to a single function which is: +// - in SSA form +// - free of function calls (except for low level calls) thanks to inlining +// - free of for-loops thanks to unrolling + +// The process happens in two steps +// 1. Shallow SSA for the `main` function +// We turn the `main` function into SSA form, but ignoring function calls and for loops +// 2. Unroll and inline +// We go through the shallow-SSA program and +// - unroll loops +// - inline function calls. This includes applying shallow-ssa on the target function + +mod inline; +mod shallow_ssa; + +use self::inline::{inline_call, InlineError}; +use crate::typed_absy::result_folder::*; +use crate::typed_absy::types::ConcreteGenericsAssignment; +use crate::typed_absy::types::GGenericsAssignment; +use crate::typed_absy::Folder; +use std::collections::HashMap; + +use crate::typed_absy::{ + ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, + DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression, + StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction, + TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, + Variable, +}; + +use std::convert::{TryFrom, TryInto}; + +use zokrates_field::Field; + +use self::shallow_ssa::ShallowTransformer; + +use crate::static_analysis::Propagator; + +use std::fmt; + +// An SSA version map, giving access to the latest version number for each identifier +pub type Versions<'ast> = HashMap, usize>; + +// A container to represent whether more treatment must be applied to the function +#[derive(Debug, PartialEq)] +pub enum Output { + Complete(U), + Incomplete(U, V), +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Error { + Incompatible(String), + GenericsInMain, + // TODO: give more details about what's blocking the progress + NoProgress, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Incompatible(s) => write!( + f, + "{}", + s + ), + Error::GenericsInMain => write!(f, "Cannot generate code for generic function"), + Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds") + } + } +} + +#[derive(Debug, Default)] +struct Substitutions<'ast>(HashMap, HashMap>); + +impl<'ast> Substitutions<'ast> { + // create an equivalent substitution map where all paths + // are of length 1 + fn canonicalize(self) -> Self { + Substitutions( + self.0 + .into_iter() + .map(|(id, sub)| (id, Self::canonicalize_sub(sub))) + .collect(), + ) + } + + // canonicalize substitutions for a given id + fn canonicalize_sub(sub: HashMap) -> HashMap { + fn add_to_cache( + sub: &HashMap, + cache: HashMap, + k: usize, + ) -> HashMap { + match cache.contains_key(&k) { + // `k` is already in the cache, no changes to the cache + true => cache, + _ => match sub.get(&k) { + // `k` does not point to anything, no changes to the cache + None => cache, + // `k` points to some `v + Some(v) => { + // add `v` to the cache + let cache = add_to_cache(sub, cache, *v); + // `k` points to what `v` points to, or to `v` + let v = cache.get(v).cloned().unwrap_or(*v); + let mut cache = cache; + cache.insert(k, v); + cache + } + }, + } + } + + sub.keys() + .fold(HashMap::new(), |cache, k| add_to_cache(&sub, cache, *k)) + } +} + +struct Sub<'a, 'ast> { + substitutions: &'a Substitutions<'ast>, +} + +impl<'a, 'ast> Sub<'a, 'ast> { + fn new(substitutions: &'a Substitutions<'ast>) -> Self { + Self { substitutions } + } +} + +impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> { + fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> { + let version = self + .substitutions + .0 + .get(&id.id) + .map(|sub| sub.get(&id.version).cloned().unwrap_or(id.version)) + .unwrap_or(id.version); + id.version(version) + } +} + +fn register<'ast>( + substitutions: &mut Substitutions<'ast>, + substitute: &Versions<'ast>, + with: &Versions<'ast>, +) { + for (id, key, value) in substitute + .iter() + .filter_map(|(id, version)| with.get(&id).clone().map(|to| (id, version, to))) + .filter(|(_, key, value)| key != value) + { + let sub = substitutions.0.entry(id.clone()).or_default(); + + // redirect `k` to `v`, unless `v` is already redirected to `v0`, in which case we redirect to `v0` + + sub.insert(*key, *sub.get(value).unwrap_or(value)); + } +} + +struct Reducer<'ast, 'a, T> { + statement_buffer: Vec>, + for_loop_versions: Vec>, + for_loop_versions_after: Vec>, + program: &'a TypedProgram<'ast, T>, + versions: &'a mut Versions<'ast>, + substitutions: &'a mut Substitutions<'ast>, + complete: bool, +} + +impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { + fn new( + program: &'a TypedProgram<'ast, T>, + versions: &'a mut Versions<'ast>, + substitutions: &'a mut Substitutions<'ast>, + for_loop_versions: Vec>, + ) -> Self { + // we reverse the vector as it's cheaper to `pop` than to take from + // the head + let mut for_loop_versions = for_loop_versions; + + for_loop_versions.reverse(); + + Reducer { + statement_buffer: vec![], + for_loop_versions_after: vec![], + for_loop_versions, + substitutions, + program, + versions, + complete: true, + } + } + + fn fold_function_call( + &mut self, + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_types: Vec>, + ) -> Result + where + E: FunctionCall<'ast, T> + TryFrom, Error = ()> + std::fmt::Debug, + { + let generics = generics + .into_iter() + .map(|g| g.map(|g| self.fold_uint_expression(g)).transpose()) + .collect::>()?; + + let arguments = arguments + .into_iter() + .map(|e| self.fold_expression(e)) + .collect::>()?; + + let res = inline_call( + key.clone(), + generics, + arguments, + output_types, + &self.program, + &mut self.versions, + ); + + match res { + Ok(Output::Complete((statements, expressions))) => { + self.complete &= true; + self.statement_buffer.extend(statements); + Ok(expressions[0].clone().try_into().unwrap()) + } + Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => { + self.complete = false; + self.statement_buffer.extend(statements); + self.for_loop_versions_after.extend(delta_for_loop_versions); + Ok(expressions[0].clone().try_into().unwrap()) + } + Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( + "Call site `{}` incompatible with declaration `{}`", + conc.to_string(), + decl.to_string() + ))), + Err(InlineError::NonConstant(key, generics, arguments, mut output_types)) => { + self.complete = false; + + Ok(E::function_call( + key, + generics, + arguments, + output_types.pop().unwrap(), + )) + } + Err(InlineError::Flat(embed, generics, arguments, output_types)) => { + let identifier = Identifier::from(CoreIdentifier::Call(0)).version( + *self + .versions + .entry(CoreIdentifier::Call(0).clone()) + .and_modify(|e| *e += 1) // if it was already declared, we increment + .or_insert(0), + ); + let var = Variable::with_id_and_type(identifier, output_types[0].clone()); + + let v = vec![var.clone().into()]; + + self.statement_buffer + .push(TypedStatement::MultipleDefinition( + v, + TypedExpressionList::EmbedCall(embed, generics, arguments, output_types), + )); + Ok(TypedExpression::from(var).try_into().unwrap()) + } + } + } +} + +impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { + type Error = Error; + + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + let res = match s { + TypedStatement::MultipleDefinition( + v, + TypedExpressionList::FunctionCall(key, generics, arguments, output_types), + ) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| self.fold_uint_expression(g)).transpose()) + .collect::>()?; + + let arguments = arguments + .into_iter() + .map(|a| self.fold_expression(a)) + .collect::>()?; + + match inline_call( + key, + generics, + arguments, + output_types, + &self.program, + &mut self.versions, + ) { + Ok(Output::Complete((statements, expressions))) => { + assert_eq!(v.len(), expressions.len()); + + self.complete &= true; + + Ok(statements + .into_iter() + .chain( + v.into_iter() + .zip(expressions) + .map(|(v, e)| TypedStatement::Definition(v, e)), + ) + .collect()) + } + Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => { + assert_eq!(v.len(), expressions.len()); + + self.complete = false; + self.for_loop_versions_after.extend(delta_for_loop_versions); + + Ok(statements + .into_iter() + .chain( + v.into_iter() + .zip(expressions) + .map(|(v, e)| TypedStatement::Definition(v, e)), + ) + .collect()) + } + Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( + "Call site `{}` incompatible with declaration `{}`", + conc.to_string(), + decl.to_string() + ))), + Err(InlineError::NonConstant(key, generics, arguments, output_types)) => { + self.complete = false; + + Ok(vec![TypedStatement::MultipleDefinition( + v, + TypedExpressionList::FunctionCall( + key, + generics, + arguments, + output_types, + ), + )]) + } + Err(InlineError::Flat(embed, generics, arguments, output_types)) => { + Ok(vec![TypedStatement::MultipleDefinition( + v, + TypedExpressionList::EmbedCall( + embed, + generics, + arguments, + output_types, + ), + )]) + } + } + } + TypedStatement::For(v, from, to, statements) => { + let versions_before = self.for_loop_versions.pop().unwrap(); + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => { + let mut out_statements = vec![]; + + // get a fresh set of versions for all variables to use as a starting point inside the loop + self.versions.values_mut().for_each(|v| *v += 1); + + // add this set of versions to the substitution, pointing to the versions before the loop + register(&mut self.substitutions, &self.versions, &versions_before); + + // the versions after the loop are found by applying an offset of 2 to the versions before the loop + let versions_after = versions_before + .clone() + .into_iter() + .map(|(k, v)| (k, v + 2)) + .collect(); + + let mut transformer = ShallowTransformer::with_versions(&mut self.versions); + + for index in *from..*to { + let statements: Vec> = + std::iter::once(TypedStatement::Definition( + v.clone().into(), + UExpression::from(index as u32).into(), + )) + .chain(statements.clone().into_iter()) + .map(|s| transformer.fold_statement(s)) + .flatten() + .collect(); + + out_statements.extend(statements); + } + + let backups = transformer.for_loop_backups; + let blocked = transformer.blocked; + + // we know the final versions of the variables after full unrolling of the loop + // the versions after the loop need to point to these, so we add to the substitutions + register(&mut self.substitutions, &versions_after, &self.versions); + + // we may have found new for loops when unrolling this one, which means new backed up versions + // we insert these in our backup list and update our cursor + + self.for_loop_versions_after.extend(backups); + + // if the ssa transform got blocked, the reduction is not complete + self.complete &= !blocked; + + Ok(out_statements) + } + _ => { + let from = self.fold_uint_expression(from)?; + let to = self.fold_uint_expression(to)?; + self.complete = false; + self.for_loop_versions_after.push(versions_before); + Ok(vec![TypedStatement::For(v, from, to, statements)]) + } + } + } + s => fold_statement(self, s), + }; + + res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + BooleanExpression::FunctionCall(key, generics, arguments) => { + self.fold_function_call(key, generics, arguments, vec![Type::Boolean]) + } + e => fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression( + &mut self, + e: UExpression<'ast, T>, + ) -> Result, Self::Error> { + match e.as_inner() { + UExpressionInner::FunctionCall(key, generics, arguments) => self.fold_function_call( + key.clone(), + generics.clone(), + arguments.clone(), + vec![e.get_type()], + ), + _ => fold_uint_expression(self, e), + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + FieldElementExpression::FunctionCall(key, generic, arguments) => { + self.fold_function_call(key, generic, arguments, vec![Type::FieldElement]) + } + e => fold_field_expression(self, e), + } + } + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + ArrayExpressionInner::FunctionCall(key, generics, arguments) => self + .fold_function_call::>( + key.clone(), + generics, + arguments.clone(), + vec![Type::array(ty.clone())], + ) + .map(|e| e.into_inner()), + ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = self.fold_array_expression(array)?; + let from = self.fold_uint_expression(from)?; + let to = self.fold_uint_expression(to)?; + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { + Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + } + _ => { + self.complete = false; + Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + } + } + } + _ => fold_array_expression_inner(self, &ty, e), + } + } + + fn fold_struct_expression( + &mut self, + e: StructExpression<'ast, T>, + ) -> Result, Self::Error> { + match e.as_inner() { + StructExpressionInner::FunctionCall(key, generics, arguments) => self + .fold_function_call( + key.clone(), + generics.clone(), + arguments.clone(), + vec![e.get_type()], + ), + _ => fold_struct_expression(self, e), + } + } +} + +pub fn reduce_program(p: TypedProgram) -> Result, Error> { + let main_module = p.modules.get(&p.main).unwrap().clone(); + + let (main_key, main_function) = main_module + .functions + .iter() + .find(|(k, _)| k.id == "main") + .unwrap(); + + let main_function = match main_function { + TypedFunctionSymbol::Here(f) => f.clone(), + _ => unreachable!(), + }; + + match main_function.signature.generics.len() { + 0 => { + let main_function = reduce_function(main_function, GGenericsAssignment::default(), &p)?; + + Ok(TypedProgram { + main: p.main.clone(), + modules: vec![( + p.main.clone(), + TypedModule { + functions: vec![( + main_key.clone(), + TypedFunctionSymbol::Here(main_function), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }) + } + _ => Err(Error::GenericsInMain), + } +} + +fn reduce_function<'ast, T: Field>( + f: TypedFunction<'ast, T>, + generics: ConcreteGenericsAssignment<'ast>, + program: &TypedProgram<'ast, T>, +) -> Result, Error> { + let mut versions = Versions::default(); + + match ShallowTransformer::transform(f, &generics, &mut versions) { + Output::Complete(f) => Ok(f), + Output::Incomplete(new_f, new_for_loop_versions) => { + let mut for_loop_versions = new_for_loop_versions; + + let mut f = new_f; + + let mut substitutions = Substitutions::default(); + + let mut constants: HashMap, TypedExpression<'ast, T>> = HashMap::new(); + + let mut hash = None; + + loop { + let mut reducer = Reducer::new( + &program, + &mut versions, + &mut substitutions, + for_loop_versions, + ); + + let new_f = TypedFunction { + statements: f + .statements + .into_iter() + .map(|s| reducer.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ..f + }; + + assert!(reducer.for_loop_versions.is_empty()); + + match reducer.complete { + true => { + substitutions = substitutions.canonicalize(); + + let new_f = Sub::new(&substitutions).fold_function(new_f); + + let new_f = Propagator::with_constants(&mut constants) + .fold_function(new_f) + .map_err(|e| match e { + crate::static_analysis::propagation::Error::Type(e) => { + Error::Incompatible(e) + } + })?; + + break Ok(new_f); + } + false => { + for_loop_versions = reducer.for_loop_versions_after; + + let new_f = Sub::new(&substitutions).fold_function(new_f); + + f = Propagator::with_constants(&mut constants) + .fold_function(new_f) + .map_err(|e| match e { + crate::static_analysis::propagation::Error::Type(e) => { + Error::Incompatible(e) + } + })?; + + let new_hash = Some(compute_hash(&f)); + + if new_hash == hash { + break Err(Error::NoProgress); + } else { + hash = new_hash + } + } + } + } + } + } +} + +fn compute_hash(f: &TypedFunction) -> u64 { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut s = DefaultHasher::new(); + f.hash(&mut s); + s.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed_absy::types::Constant; + use crate::typed_absy::types::DeclarationSignature; + use crate::typed_absy::{ + ArrayExpressionInner, DeclarationFunctionKey, DeclarationType, DeclarationVariable, + FieldElementExpression, Identifier, Select, Type, TypedExpression, TypedExpressionList, + TypedExpressionOrSpread, UBitwidth, UExpressionInner, Variable, + }; + use zokrates_field::Bn128Field; + + #[test] + fn no_generics() { + // def foo(field a) -> field: + // return a + // def main(field a) -> field: + // u32 n = 42 + // n = n + // a = a + // a = foo(a) + // n = n + // return a + + // expected: + // def main(field a_0) -> field: + // a_1 = a_0 + // # PUSH CALL to foo + // a_3 := a_1 // input binding + // #RETURN_AT_INDEX_0_0 := a_3 + // # POP CALL + // a_2 = #RETURN_AT_INDEX_0_0 + // return a_2 + + let foo: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![TypedStatement::Return(vec![ + FieldElementExpression::Identifier("a".into()).into(), + ])], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let main: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::MultipleDefinition( + vec![Variable::field_element("a").into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + vec![], + vec![FieldElementExpression::Identifier("a".into()).into()], + vec![Type::FieldElement], + ), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Return(vec![FieldElementExpression::Identifier("a".into()).into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let p = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![ + ( + DeclarationFunctionKey::with_location("main", "foo").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(foo), + ), + ( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let reduced = reduce_program(p); + + let expected_main = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(1)).into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::PushCallLog( + DeclarationFunctionKey::with_location("main", "foo").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + GGenericsAssignment::default(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(3)).into(), + FieldElementExpression::Identifier(Identifier::from("a").version(1)).into(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0)) + .into(), + FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(), + ), + TypedStatement::PopCallLog, + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(2)).into(), + FieldElementExpression::Identifier( + Identifier::from(CoreIdentifier::Call(0)).version(0), + ) + .into(), + ), + TypedStatement::Return(vec![FieldElementExpression::Identifier( + Identifier::from("a").version(2), + ) + .into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected: TypedProgram = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(reduced.unwrap(), expected); + } + + #[test] + fn with_generics() { + // def foo(field[K] a) -> field[K]: + // return a + // def main(field a) -> field: + // u32 n = 42 + // n = n + // field[1] b = [a] + // b = foo(b) + // n = n + // return a + b[0] + + // expected: + // def main(field a_0) -> field: + // field[1] b_0 = [42] + // # PUSH CALL to foo::<1> + // a_0 = b_0 + // #RETURN_AT_INDEX_0_0 := a_0 + // # POP CALL + // b_1 = #RETURN_AT_INDEX_0_0 + // return a_2 + b_1[0] + + let foo_signature = DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]) + .outputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]); + + let foo: TypedFunction = TypedFunction { + arguments: vec![ + DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), + ], + statements: vec![TypedStatement::Return(vec![ + ArrayExpressionInner::Identifier("a".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ])], + signature: foo_signature.clone(), + }; + + let main: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::array("b", Type::FieldElement, 1u32).into(), + ArrayExpressionInner::Value( + vec![FieldElementExpression::Identifier("a".into()).into()].into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::MultipleDefinition( + vec![Variable::array("b", Type::FieldElement, 1u32).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + vec![None], + vec![ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into()], + vec![Type::array((Type::FieldElement, 1u32))], + ), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Return(vec![(FieldElementExpression::Identifier("a".into()) + + FieldElementExpression::select( + ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + 0u32, + )) + .into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let p = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![ + ( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + TypedFunctionSymbol::Here(foo), + ), + ( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let reduced = reduce_program(p); + + let expected_main = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::array("b", Type::FieldElement, 1u32).into(), + ArrayExpressionInner::Value( + vec![FieldElementExpression::Identifier("a".into()).into()].into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::PushCallLog( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + GGenericsAssignment(vec![("K", 1)].into_iter().collect()), + ), + TypedStatement::Definition( + Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + .into(), + ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Definition( + Variable::array( + Identifier::from(CoreIdentifier::Call(0)).version(0), + Type::FieldElement, + 1u32, + ) + .into(), + ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::PopCallLog, + TypedStatement::Definition( + Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) + .into(), + ArrayExpressionInner::Identifier( + Identifier::from(CoreIdentifier::Call(0)).version(0), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Return(vec![(FieldElementExpression::Identifier("a".into()) + + FieldElementExpression::select( + ArrayExpressionInner::Identifier(Identifier::from("b").version(1)) + .annotate(Type::FieldElement, 1u32) + .into(), + 0u32, + )) + .into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(reduced.unwrap(), expected); + } + + #[test] + fn with_generics_and_propagation() { + // def foo(field[K] a) -> field[K]: + // return a + // def main(field a) -> field: + // u32 n = 2 + // n = n + // field[n - 1] b = [a] + // b = foo(b) + // n = n + // return a + b[0] + + // expected: + // def main(field a_0) -> field: + // field[1] b_0 = [42] + // # PUSH CALL to foo::<1> + // a_0 = b_0 + // #RETURN_AT_INDEX_0_0 := a_0 + // # POP CALL + // b_1 = #RETURN_AT_INDEX_0_0 + // return a_2 + b_1[0] + + let foo_signature = DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]) + .outputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]); + + let foo: TypedFunction = TypedFunction { + arguments: vec![ + DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), + ], + statements: vec![TypedStatement::Return(vec![ + ArrayExpressionInner::Identifier("a".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ])], + signature: foo_signature.clone(), + }; + + let main: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(2u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::array( + "b", + Type::FieldElement, + UExpressionInner::Sub( + box UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32), + box 1u32.into(), + ) + .annotate(UBitwidth::B32), + ) + .into(), + ArrayExpressionInner::Value( + vec![FieldElementExpression::Identifier("a".into()).into()].into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::MultipleDefinition( + vec![Variable::array("b", Type::FieldElement, 1u32).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + vec![None], + vec![ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into()], + vec![Type::array((Type::FieldElement, 1u32))], + ), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Return(vec![(FieldElementExpression::Identifier("a".into()) + + FieldElementExpression::select( + ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + 0u32, + )) + .into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let p = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![ + ( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + TypedFunctionSymbol::Here(foo), + ), + ( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(main), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let reduced = reduce_program(p); + + let expected_main = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::array("b", Type::FieldElement, 1u32).into(), + ArrayExpressionInner::Value( + vec![FieldElementExpression::Identifier("a".into()).into()].into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::PushCallLog( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + GGenericsAssignment(vec![("K", 1)].into_iter().collect()), + ), + TypedStatement::Definition( + Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + .into(), + ArrayExpressionInner::Identifier("b".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Definition( + Variable::array( + Identifier::from(CoreIdentifier::Call(0)).version(0), + Type::FieldElement, + 1u32, + ) + .into(), + ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::PopCallLog, + TypedStatement::Definition( + Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) + .into(), + ArrayExpressionInner::Identifier( + Identifier::from(CoreIdentifier::Call(0)).version(0), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Return(vec![(FieldElementExpression::Identifier("a".into()) + + FieldElementExpression::select( + ArrayExpressionInner::Identifier(Identifier::from("b").version(1)) + .annotate(Type::FieldElement, 1u32) + .into(), + 0u32, + )) + .into()]), + ], + signature: DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let expected = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(reduced.unwrap(), expected); + } + + #[test] + fn call_in_call() { + // we use a global ssa counter, hence reusing variable names in called functions + // leads to counter increase + + // def bar(field[K] a) -> field[K]: + // return a + + // def foo(field[K] a) -> field[K]: + // field[K] ret = bar([...a, 0])[0..K] + // return ret + + // def main(): + // field[1] b = foo([1]) + // return + + // expected: + // def main(): + // # PUSH CALL to foo::<1> + // # PUSH CALL to bar::<2> + // field[2] a_1 = [...[1]], 0] + // field[2] #RET_0_1 = a_1 + // # POP CALL + // field[1] ret := #RET_0_1[0..1] + // field[1] #RET_0 = ret + // # POP CALL + // field[1] b_0 := #RET_0 + // return + + let foo_signature = DeclarationSignature::new() + .inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]) + .outputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]) + .generics(vec![Some("K".into())]); + + let foo: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::array( + "a", + DeclarationType::FieldElement, + Constant::Generic("K".into()), + ) + .into()], + statements: vec![ + TypedStatement::Definition( + Variable::array( + "ret", + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32), + ) + .into(), + ArrayExpressionInner::Slice( + box ArrayExpressionInner::FunctionCall( + DeclarationFunctionKey::with_location("main", "bar") + .signature(foo_signature.clone()), + vec![None], + vec![ArrayExpressionInner::Value( + vec![ + TypedExpressionOrSpread::Spread( + ArrayExpressionInner::Identifier("a".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + FieldElementExpression::Number(Bn128Field::from(0)).into(), + ] + .into(), + ) + .annotate( + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32) + + 1u32.into(), + ) + .into()], + ) + .annotate( + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32) + + 1u32.into(), + ), + box 0u32.into(), + box UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32), + ) + .annotate( + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32), + ) + .into(), + ), + TypedStatement::Return(vec![ArrayExpressionInner::Identifier("ret".into()) + .annotate( + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32), + ) + .into()]), + ], + signature: foo_signature.clone(), + }; + + let bar_signature = foo_signature.clone(); + + let bar: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::array( + "a", + DeclarationType::FieldElement, + Constant::Generic("K".into()), + ) + .into()], + statements: vec![TypedStatement::Return(vec![ + ArrayExpressionInner::Identifier("a".into()) + .annotate( + Type::FieldElement, + UExpressionInner::Identifier("K".into()).annotate(UBitwidth::B32), + ) + .into(), + ])], + signature: bar_signature.clone(), + }; + + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![ + TypedStatement::MultipleDefinition( + vec![Variable::array("b", Type::FieldElement, 1u32).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + vec![None], + vec![ArrayExpressionInner::Value( + vec![FieldElementExpression::Number(Bn128Field::from(1)).into()].into(), + ) + .annotate(Type::FieldElement, 1u32) + .into()], + vec![Type::array((Type::FieldElement, 1u32))], + ), + ), + TypedStatement::Return(vec![]), + ], + signature: DeclarationSignature::new(), + }; + + let p = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![ + ( + DeclarationFunctionKey::with_location("main", "bar") + .signature(bar_signature.clone()), + TypedFunctionSymbol::Here(bar), + ), + ( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + TypedFunctionSymbol::Here(foo), + ), + ( + DeclarationFunctionKey::with_location("main", "main"), + TypedFunctionSymbol::Here(main), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let reduced = reduce_program(p); + + let expected_main = TypedFunction { + arguments: vec![], + statements: vec![ + TypedStatement::PushCallLog( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + GGenericsAssignment(vec![("K", 1)].into_iter().collect()), + ), + TypedStatement::PushCallLog( + DeclarationFunctionKey::with_location("main", "bar") + .signature(foo_signature.clone()), + GGenericsAssignment(vec![("K", 2)].into_iter().collect()), + ), + TypedStatement::Definition( + Variable::array(Identifier::from("a").version(1), Type::FieldElement, 2u32) + .into(), + ArrayExpressionInner::Value( + vec![ + TypedExpressionOrSpread::Spread( + ArrayExpressionInner::Value( + vec![TypedExpressionOrSpread::Expression( + FieldElementExpression::Number(Bn128Field::from(1)).into(), + )] + .into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + FieldElementExpression::Number(Bn128Field::from(0)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ), + TypedStatement::Definition( + Variable::array( + Identifier::from(CoreIdentifier::Call(0)).version(1), + Type::FieldElement, + 2u32, + ) + .into(), + ArrayExpressionInner::Identifier(Identifier::from("a").version(1)) + .annotate(Type::FieldElement, 2u32) + .into(), + ), + TypedStatement::PopCallLog, + TypedStatement::Definition( + Variable::array("ret", Type::FieldElement, 1u32).into(), + ArrayExpressionInner::Slice( + box ArrayExpressionInner::Identifier( + Identifier::from(CoreIdentifier::Call(0)).version(1), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + box 0u32.into(), + box 1u32.into(), + ) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Definition( + Variable::array( + Identifier::from(CoreIdentifier::Call(0)), + Type::FieldElement, + 1u32, + ) + .into(), + ArrayExpressionInner::Identifier("ret".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::PopCallLog, + TypedStatement::Definition( + Variable::array("b", Type::FieldElement, 1u32).into(), + ArrayExpressionInner::Identifier(Identifier::from(CoreIdentifier::Call(0))) + .annotate(Type::FieldElement, 1u32) + .into(), + ), + TypedStatement::Return(vec![]), + ], + signature: DeclarationSignature::new(), + }; + + let expected = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![( + DeclarationFunctionKey::with_location("main", "main") + .signature(DeclarationSignature::new()), + TypedFunctionSymbol::Here(expected_main), + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + assert_eq!(reduced.unwrap(), expected); + } + + #[test] + fn incompatible() { + // def foo(field[K] a) -> field[K]: + // return a + // def main(): + // field[1] b = foo([]) + // return + + // expected: + // Error: Incompatible + + let foo_signature = DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]) + .outputs(vec![DeclarationType::array(( + DeclarationType::FieldElement, + Constant::Generic("K"), + ))]); + + let foo: TypedFunction = TypedFunction { + arguments: vec![ + DeclarationVariable::array("a", DeclarationType::FieldElement, "K").into(), + ], + statements: vec![TypedStatement::Return(vec![ + ArrayExpressionInner::Identifier("a".into()) + .annotate(Type::FieldElement, 1u32) + .into(), + ])], + signature: foo_signature.clone(), + }; + + let main: TypedFunction = TypedFunction { + arguments: vec![], + statements: vec![ + TypedStatement::MultipleDefinition( + vec![Variable::array("b", Type::FieldElement, 1u32).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + vec![None], + vec![ArrayExpressionInner::Value(vec![].into()) + .annotate(Type::FieldElement, 0u32) + .into()], + vec![Type::array((Type::FieldElement, 1u32))], + ), + ), + TypedStatement::Return(vec![]), + ], + signature: DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + }; + + let p = TypedProgram { + main: "main".into(), + modules: vec![( + "main".into(), + TypedModule { + functions: vec![ + ( + DeclarationFunctionKey::with_location("main", "foo") + .signature(foo_signature.clone()), + TypedFunctionSymbol::Here(foo), + ), + ( + DeclarationFunctionKey::with_location("main", "main").signature( + DeclarationSignature::new().inputs(vec![]).outputs(vec![]), + ), + TypedFunctionSymbol::Here(main), + ), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; + + let reduced = reduce_program(p); + + assert_eq!( + reduced, + Err(Error::Incompatible("Call site `main/foo<_>(field[0]) -> field[1]` incompatible with declaration `main/foo(field[K]) -> field[K]`".into())) + ); + } +} diff --git a/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs new file mode 100644 index 00000000..98325f5e --- /dev/null +++ b/zokrates_core/src/static_analysis/reducer/shallow_ssa.rs @@ -0,0 +1,1024 @@ +// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can +// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility. +// Function calls are also left unvisited +// Saving the indices is not required for function calls, as they cannot modify their environment + +// Example: +// def main(field a) -> field: +// u32 n = 42 +// a = a + 1 +// field b = foo(a) +// for u32 i in 0..n: +// +// endfor +// return b + +// Should be turned into +// def main(field a_0) -> field: +// u32 n_0 = 42 +// a_1 = a_0 + 1 +// field b_0 = foo(a_1) // we keep the function call as is +// # versions: {n: 0, a: 1, b: 0} +// for u32 i_0 in 0..n_0: +// // we keep the loop body as is +// endfor +// return b_3 // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop + +use crate::typed_absy::folder::*; +use crate::typed_absy::types::ConcreteGenericsAssignment; +use crate::typed_absy::types::Type; +use crate::typed_absy::*; + +use zokrates_field::Field; + +use super::{Output, Versions}; + +pub struct ShallowTransformer<'ast, 'a> { + // version index for any variable name + pub versions: &'a mut Versions<'ast>, + // A backup of the versions before each for-loop + pub for_loop_backups: Vec>, + // whether all statements could be unrolled so far. Loops with variable bounds cannot. + pub blocked: bool, +} + +impl<'ast, 'a> ShallowTransformer<'ast, 'a> { + pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { + ShallowTransformer { + versions, + for_loop_backups: Vec::default(), + blocked: false, + } + } + + // increase all versions by 2 and return the old versions + fn create_version_gap(&mut self) -> Versions<'ast> { + let ret = self.versions.clone(); + self.versions.values_mut().for_each(|v| *v += 2); + ret + } + + fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { + let version = *self + .versions + .entry(c_id.clone()) + .and_modify(|e| *e += 1) // if it was already declared, we increment + .or_insert(0); // otherwise, we start from this version + + Identifier::from(c_id).version(version) + } + + fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { + assert_eq!(v.id.version, 0); + + Variable { + id: self.issue_next_identifier(v.id.id), + ..v + } + } + + pub fn transform( + f: TypedFunction<'ast, T>, + generics: &ConcreteGenericsAssignment<'ast>, + versions: &'a mut Versions<'ast>, + ) -> Output, Vec>> { + let mut unroller = ShallowTransformer::with_versions(versions); + + let f = unroller.fold_function(f, generics); + + match unroller.blocked { + false => Output::Complete(f), + true => Output::Incomplete(f, unroller.for_loop_backups), + } + } + + fn fold_function( + &mut self, + f: TypedFunction<'ast, T>, + generics: &ConcreteGenericsAssignment<'ast>, + ) -> TypedFunction<'ast, T> { + let mut f = f; + + f.statements = generics + .0 + .iter() + .map(|(g, v)| { + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::with_id_and_type( + *g, + Type::Uint(UBitwidth::B32), + )), + UExpression::from(*v as u32).into(), + ) + }) + .chain(f.statements) + .collect(); + + for arg in &f.arguments { + let _ = self.issue_next_identifier(arg.id.id.id.clone()); + } + + fold_function(self, f) + } +} + +impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { + match s { + TypedStatement::Declaration(_) => vec![], + TypedStatement::Definition(a, e) => { + let e = self.fold_expression(e); + + let a = match a { + TypedAssignee::Identifier(v) => { + let v = self.issue_next_ssa_variable(v); + TypedAssignee::Identifier(self.fold_variable(v)) + } + a => fold_assignee(self, a), + }; + + vec![TypedStatement::Definition(a, e)] + } + TypedStatement::MultipleDefinition(assignees, exprs) => { + let exprs = self.fold_expression_list(exprs); + let assignees = assignees + .into_iter() + .map(|a| match a { + TypedAssignee::Identifier(v) => { + let v = self.issue_next_ssa_variable(v); + TypedAssignee::Identifier(self.fold_variable(v)) + } + a => fold_assignee(self, a), + }) + .collect(); + + vec![TypedStatement::MultipleDefinition(assignees, exprs)] + } + TypedStatement::For(v, from, to, stats) => { + let from = self.fold_uint_expression(from); + let to = self.fold_uint_expression(to); + self.blocked = true; + let versions_before_loop = self.create_version_gap(); + self.for_loop_backups.push(versions_before_loop); + vec![TypedStatement::For(v, from, to, stats)] + } + s => fold_statement(self, s), + } + } + + fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { + let res = Identifier { + version: *self.versions.get(&(n.id)).unwrap_or(&0), + ..n + }; + res + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + if let FieldElementExpression::FunctionCall(ref k, _, _) = e { + if !k.id.starts_with('_') { + self.blocked = true; + } + } + + fold_field_expression(self, e) + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + if let BooleanExpression::FunctionCall(ref k, _, _) = e { + if !k.id.starts_with('_') { + self.blocked = true; + } + }; + + fold_boolean_expression(self, e) + } + + fn fold_uint_expression_inner( + &mut self, + b: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + if let UExpressionInner::FunctionCall(ref k, _, _) = e { + if !k.id.starts_with('_') { + self.blocked = true; + } + }; + + fold_uint_expression_inner(self, b, e) + } + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + if let ArrayExpressionInner::FunctionCall(ref k, _, _) = e { + if !k.id.starts_with('_') { + self.blocked = true; + } + }; + + fold_array_expression_inner(self, ty, e) + } + + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + if let StructExpressionInner::FunctionCall(ref k, _, _) = e { + if !k.id.starts_with('_') { + self.blocked = true; + } + }; + + fold_struct_expression_inner(self, ty, e) + } + + fn fold_expression_list( + &mut self, + e: TypedExpressionList<'ast, T>, + ) -> TypedExpressionList<'ast, T> { + match e { + TypedExpressionList::FunctionCall(ref k, _, _, _) => { + if !k.id.starts_with('_') { + self.blocked = true; + } + } + _ => unreachable!(), + }; + + fold_expression_list(self, e) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::typed_absy::types::DeclarationSignature; + use zokrates_field::Bn128Field; + mod normal { + use super::*; + + #[test] + fn detect_non_constant_bound() { + let loops: Vec> = vec![TypedStatement::For( + Variable::uint("i", UBitwidth::B32), + UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), + 2u32.into(), + vec![], + )]; + + let statements = loops; + + let f = TypedFunction { + arguments: vec![], + signature: DeclarationSignature::new(), + statements, + }; + + match ShallowTransformer::transform( + f, + &ConcreteGenericsAssignment::default(), + &mut Versions::default(), + ) { + Output::Incomplete(..) => {} + _ => unreachable!(), + }; + } + + #[test] + fn definition() { + // field a + // a = 5 + // a = 6 + // a + + // should be turned into + // a_0 = 5 + // a_1 = 6 + // a_1 + + let mut versions = Versions::new(); + + let mut u = ShallowTransformer::with_versions(&mut versions); + let s: TypedStatement = + TypedStatement::Declaration(Variable::field_element("a")); + assert_eq!(u.fold_statement(s), vec![]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(0) + )), + FieldElementExpression::Number(Bn128Field::from(5)).into() + )] + ); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Number(Bn128Field::from(6)).into(), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(1) + )), + FieldElementExpression::Number(Bn128Field::from(6)).into() + )] + ); + + let e: FieldElementExpression = + FieldElementExpression::Identifier("a".into()); + assert_eq!( + u.fold_field_expression(e), + FieldElementExpression::Identifier(Identifier::from("a").version(1)) + ); + } + + #[test] + fn incremental_definition() { + // field a + // a = 5 + // a = a + 1 + + // should be turned into + // a_0 = 5 + // a_1 = a_0 + 1 + + let mut versions = Versions::new(); + + let mut u = ShallowTransformer::with_versions(&mut versions); + + let s: TypedStatement = + TypedStatement::Declaration(Variable::field_element("a")); + assert_eq!(u.fold_statement(s), vec![]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(0) + )), + FieldElementExpression::Number(Bn128Field::from(5)).into() + )] + ); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Add( + box FieldElementExpression::Identifier("a".into()), + box FieldElementExpression::Number(Bn128Field::from(1)), + ) + .into(), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(1) + )), + FieldElementExpression::Add( + box FieldElementExpression::Identifier(Identifier::from("a").version(0)), + box FieldElementExpression::Number(Bn128Field::from(1)) + ) + .into() + )] + ); + } + + #[test] + fn incremental_multiple_definition() { + use crate::typed_absy::types::Type; + + // field a + // a = 2 + // a = foo(a) + + // should be turned into + // a_0 = 2 + // a_1 = foo(a_0) + + let mut versions = Versions::new(); + + let mut u = ShallowTransformer::with_versions(&mut versions); + + let s: TypedStatement = + TypedStatement::Declaration(Variable::field_element("a")); + assert_eq!(u.fold_statement(s), vec![]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element("a")), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::field_element( + Identifier::from("a").version(0) + )), + FieldElementExpression::Number(Bn128Field::from(2)).into() + )] + ); + + let s: TypedStatement = TypedStatement::MultipleDefinition( + vec![Variable::field_element("a").into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + ), + vec![], + vec![FieldElementExpression::Identifier("a".into()).into()], + vec![Type::FieldElement], + ), + ); + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::MultipleDefinition( + vec![Variable::field_element(Identifier::from("a").version(1)).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo").signature( + DeclarationSignature::new() + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]) + ), + vec![], + vec![ + FieldElementExpression::Identifier(Identifier::from("a").version(0)) + .into() + ], + vec![Type::FieldElement], + ) + )] + ); + } + + #[test] + fn incremental_array_definition() { + // field[2] a = [1, 1] + // a[1] = 2 + + // should be turned into + // a_0 = [1, 1] + // a_0[1] = 2 + + let mut versions = Versions::new(); + + let mut u = ShallowTransformer::with_versions(&mut versions); + + let s: TypedStatement = + TypedStatement::Declaration(Variable::array("a", Type::FieldElement, 2u32)); + assert_eq!(u.fold_statement(s), vec![]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::Number(Bn128Field::from(1)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ); + + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::array( + Identifier::from("a").version(0), + Type::FieldElement, + 2u32 + )), + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(1)).into(), + FieldElementExpression::Number(Bn128Field::from(1)).into() + ] + .into() + ) + .annotate(Type::FieldElement, 2u32) + .into() + )] + ); + + let s: TypedStatement = TypedStatement::Definition( + TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), + box UExpression::from(1u32), + ), + FieldElementExpression::Number(Bn128Field::from(2)).into(), + ); + + assert_eq!(u.fold_statement(s.clone()), vec![s]); + } + + #[test] + fn incremental_array_of_arrays_definition() { + // field[2][2] a = [[0, 1], [2, 3]] + // a[1] = [4, 5] + + // should be turned into + // a_0 = [[0, 1], [2, 3]] + // a_0 = [4, 5] + + let mut versions = Versions::new(); + + let mut u = ShallowTransformer::with_versions(&mut versions); + + let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); + + let s: TypedStatement = TypedStatement::Declaration( + Variable::with_id_and_type("a", array_of_array_ty.clone()), + ); + assert_eq!(u.fold_statement(s), vec![]); + + let s = TypedStatement::Definition( + TypedAssignee::Identifier(Variable::with_id_and_type( + "a", + array_of_array_ty.clone(), + )), + ArrayExpressionInner::Value( + vec![ + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(0)).into(), + FieldElementExpression::Number(Bn128Field::from(1)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(3)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ] + .into(), + ) + .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + .into(), + ); + + assert_eq!( + u.fold_statement(s), + vec![TypedStatement::Definition( + TypedAssignee::Identifier(Variable::with_id_and_type( + Identifier::from("a").version(0), + array_of_array_ty.clone(), + )), + ArrayExpressionInner::Value( + vec![ + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(0)).into(), + FieldElementExpression::Number(Bn128Field::from(1)).into(), + ] + .into() + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(2)).into(), + FieldElementExpression::Number(Bn128Field::from(3)).into(), + ] + .into() + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ] + .into() + ) + .annotate(Type::array((Type::FieldElement, 2u32)), 2u32) + .into(), + )] + ); + + let s: TypedStatement = TypedStatement::Definition( + TypedAssignee::Select( + box TypedAssignee::Identifier(Variable::with_id_and_type( + "a", + array_of_array_ty.clone(), + )), + box UExpression::from(1u32), + ), + ArrayExpressionInner::Value( + vec![ + FieldElementExpression::Number(Bn128Field::from(4)).into(), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + ] + .into(), + ) + .annotate(Type::FieldElement, 2u32) + .into(), + ); + + assert_eq!(u.fold_statement(s.clone()), vec![s]); + } + } + + mod for_loop { + use super::*; + use crate::typed_absy::types::GGenericsAssignment; + #[test] + fn treat_loop() { + // def main(field a) -> field: + // u32 n = 42 + // n = n + // a = a + // for u32 i in n..n*n: + // a = a + // endfor + // a = a + // for u32 i in n..n*n: + // a = a + // endfor + // a = a + // return a + + // When called with K := 1, expected: + // def main(field a_0) -> field: + // u32 K = 1 + // u32 n_0 = 42 + // n_1 = n_0 + // a_1 = a_0 + // # versions: {n: 1, a: 1} + // for u32 i_0 in n_0..n_0*n_0: + // a_0 = a_0 + // endfor + // a_4 = a_3 + // # versions: {n: 3, a: 4} + // for u32 i_0 in n_0..n_0*n_0: + // a_0 = a_0 + // endfor + // a_7 = a_6 + // return a_7 + // # versions: {n: 5, a: 7} + + let f: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::For( + Variable::uint("i", UBitwidth::B32), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32) + * UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + vec![TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + )], + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::For( + Variable::uint("i", UBitwidth::B32), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + UExpressionInner::Identifier("n".into()).annotate(UBitwidth::B32) + * UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + vec![TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + )], + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::Return(vec![ + FieldElementExpression::Identifier("a".into()).into() + ]), + ], + signature: DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let mut versions = Versions::default(); + + let ssa = ShallowTransformer::transform( + f, + &GGenericsAssignment(vec![("K".into(), 1)].into_iter().collect()), + &mut versions, + ); + + let expected = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("K", UBitwidth::B32).into(), + TypedExpression::Uint(1u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(1)).into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::For( + Variable::uint("i", UBitwidth::B32), + UExpressionInner::Identifier(Identifier::from("n").version(1)) + .annotate(UBitwidth::B32) + .into(), + UExpressionInner::Identifier(Identifier::from("n").version(1)) + .annotate(UBitwidth::B32) + * UExpressionInner::Identifier(Identifier::from("n").version(1)) + .annotate(UBitwidth::B32) + .into(), + vec![TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + )], + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(4)).into(), + FieldElementExpression::Identifier(Identifier::from("a").version(3)).into(), + ), + TypedStatement::For( + Variable::uint("i", UBitwidth::B32), + UExpressionInner::Identifier(Identifier::from("n").version(3)) + .annotate(UBitwidth::B32) + .into(), + UExpressionInner::Identifier(Identifier::from("n").version(3)) + .annotate(UBitwidth::B32) + * UExpressionInner::Identifier(Identifier::from("n").version(3)) + .annotate(UBitwidth::B32) + .into(), + vec![TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + )], + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(7)).into(), + FieldElementExpression::Identifier(Identifier::from("a").version(6)).into(), + ), + TypedStatement::Return(vec![FieldElementExpression::Identifier( + Identifier::from("a").version(7), + ) + .into()]), + ], + signature: DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + assert_eq!( + versions, + vec![("n".into(), 5), ("a".into(), 7), ("K".into(), 4)] + .into_iter() + .collect::() + ); + + let expected = Output::Incomplete( + expected, + vec![ + vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)] + .into_iter() + .collect::(), + vec![("n".into(), 3), ("a".into(), 4), ("K".into(), 2)] + .into_iter() + .collect::(), + ], + ); + + assert_eq!(ssa, expected); + } + } + + mod function_call { + use super::*; + use crate::typed_absy::types::GGenericsAssignment; + // test that function calls are left in + #[test] + fn treat_calls() { + // def main(field a) -> field: + // u32 n = 42 + // n = n + // a = a + // a = foo::(a) + // n = n + // a = a * foo::(a) + // return a + + // When called with K := 1, expected: + // def main(field a_0) -> field: + // K = 1 + // u32 n_0 = 42 + // n_1 = n_0 + // a_1 = a_0 + // a_2 = foo::(a_1) + // n_2 = n_1 + // a_3 = a_2 * foo::(a_2) + // return a_3 + // # versions: {n: 2, a: 3} + + let f: TypedFunction = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::MultipleDefinition( + vec![Variable::field_element("a").into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo"), + vec![Some( + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + )], + vec![FieldElementExpression::Identifier("a".into()).into()], + vec![Type::FieldElement], + ), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element("a").into(), + (FieldElementExpression::Identifier("a".into()) + * FieldElementExpression::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo"), + vec![Some( + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + )], + vec![FieldElementExpression::Identifier("a".into()).into()], + )) + .into(), + ), + TypedStatement::Return(vec![ + FieldElementExpression::Identifier("a".into()).into() + ]), + ], + signature: DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + let mut versions = Versions::default(); + + let ssa = ShallowTransformer::transform( + f, + &GGenericsAssignment(vec![("K".into(), 1)].into_iter().collect()), + &mut versions, + ); + + let expected = TypedFunction { + arguments: vec![DeclarationVariable::field_element("a").into()], + statements: vec![ + TypedStatement::Definition( + Variable::uint("K", UBitwidth::B32).into(), + TypedExpression::Uint(1u32.into()), + ), + TypedStatement::Definition( + Variable::uint("n", UBitwidth::B32).into(), + TypedExpression::Uint(42u32.into()), + ), + TypedStatement::Definition( + Variable::uint(Identifier::from("n").version(1), UBitwidth::B32).into(), + UExpressionInner::Identifier("n".into()) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(1)).into(), + FieldElementExpression::Identifier("a".into()).into(), + ), + TypedStatement::MultipleDefinition( + vec![Variable::field_element(Identifier::from("a").version(2)).into()], + TypedExpressionList::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo"), + vec![Some( + UExpressionInner::Identifier(Identifier::from("n").version(1)) + .annotate(UBitwidth::B32) + .into(), + )], + vec![FieldElementExpression::Identifier( + Identifier::from("a").version(1), + ) + .into()], + vec![Type::FieldElement], + ), + ), + TypedStatement::Definition( + Variable::uint(Identifier::from("n").version(2), UBitwidth::B32).into(), + UExpressionInner::Identifier(Identifier::from("n").version(1)) + .annotate(UBitwidth::B32) + .into(), + ), + TypedStatement::Definition( + Variable::field_element(Identifier::from("a").version(3)).into(), + (FieldElementExpression::Identifier(Identifier::from("a").version(2)) + * FieldElementExpression::FunctionCall( + DeclarationFunctionKey::with_location("main", "foo"), + vec![Some( + UExpressionInner::Identifier(Identifier::from("n").version(2)) + .annotate(UBitwidth::B32) + .into(), + )], + vec![FieldElementExpression::Identifier( + Identifier::from("a").version(2), + ) + .into()], + )) + .into(), + ), + TypedStatement::Return(vec![FieldElementExpression::Identifier( + Identifier::from("a").version(3), + ) + .into()]), + ], + signature: DeclarationSignature::new() + .generics(vec![Some("K".into())]) + .inputs(vec![DeclarationType::FieldElement]) + .outputs(vec![DeclarationType::FieldElement]), + }; + + assert_eq!( + versions, + vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] + .into_iter() + .collect::() + ); + + assert_eq!(ssa, Output::Incomplete(expected, vec![],)); + } + } +} diff --git a/zokrates_core/src/static_analysis/return_binder.rs b/zokrates_core/src/static_analysis/return_binder.rs deleted file mode 100644 index 5320a7f9..00000000 --- a/zokrates_core/src/static_analysis/return_binder.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::typed_absy::folder::fold_statement; -use crate::typed_absy::identifier::CoreIdentifier; -use crate::typed_absy::*; -use zokrates_field::Field; - -pub struct ReturnBinder; - -impl ReturnBinder { - pub fn bind<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { - ReturnBinder {}.fold_program(p) - } -} - -impl<'ast, T: Field> Folder<'ast, T> for ReturnBinder { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - match s { - TypedStatement::Return(exprs) => { - let ret_identifiers: Vec> = (0..exprs.len()) - .map(|i| CoreIdentifier::Internal("RETURN", i).into()) - .collect(); - - let ret_expressions: Vec> = exprs - .iter() - .zip(ret_identifiers.iter()) - .map(|(e, i)| match e.get_type() { - Type::FieldElement => FieldElementExpression::Identifier(i.clone()).into(), - Type::Boolean => BooleanExpression::Identifier(i.clone()).into(), - Type::Array(array_type) => ArrayExpressionInner::Identifier(i.clone()) - .annotate(*array_type.ty, array_type.size) - .into(), - Type::Struct(struct_type) => StructExpressionInner::Identifier(i.clone()) - .annotate(struct_type) - .into(), - Type::Uint(bitwidth) => UExpressionInner::Identifier(i.clone()) - .annotate(bitwidth) - .into(), - }) - .collect(); - - exprs - .into_iter() - .zip(ret_identifiers.iter()) - .map(|(e, i)| { - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::with_id_and_type( - i.clone(), - e.get_type(), - )), - e, - ) - }) - .chain(std::iter::once(TypedStatement::Return(ret_expressions))) - .collect() - } - s => fold_statement(self, s), - } - } -} diff --git a/zokrates_core/src/static_analysis/trimmer.rs b/zokrates_core/src/static_analysis/trimmer.rs new file mode 100644 index 00000000..1db07400 --- /dev/null +++ b/zokrates_core/src/static_analysis/trimmer.rs @@ -0,0 +1,88 @@ +use crate::typed_absy::TypedModule; +use crate::typed_absy::{TypedFunctionSymbol, TypedProgram}; +use zokrates_field::Field; + +pub struct Trimmer; + +impl Trimmer { + pub fn trim<'ast, T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { + let main_module_id = p.main; + + // get the main module + let main_module = p.modules.get(&main_module_id).unwrap().clone(); + + // get the main function in the main module + let (main_key, main_function) = main_module + .functions + .into_iter() + .find(|(k, _)| k.id == "main") + .unwrap() + .clone(); + + // define a function in the main module for the `unpack` embed + let unpack = crate::embed::FlatEmbed::Unpack(T::get_required_bits()); + let unpack_key = unpack.key::(); + + // define a function in the main module for the `u32_to_bits` embed + let u32_to_bits = crate::embed::FlatEmbed::U32ToBits; + let u32_to_bits_key = u32_to_bits.key::(); + + // define a function in the main module for the `u16_to_bits` embed + let u16_to_bits = crate::embed::FlatEmbed::U16ToBits; + let u16_to_bits_key = u16_to_bits.key::(); + + // define a function in the main module for the `u8_to_bits` embed + let u8_to_bits = crate::embed::FlatEmbed::U8ToBits; + let u8_to_bits_key = u8_to_bits.key::(); + + // define a function in the main module for the `u32_from_bits` embed + let u32_from_bits = crate::embed::FlatEmbed::U32FromBits; + let u32_from_bits_key = u32_from_bits.key::(); + + // define a function in the main module for the `u16_from_bits` embed + let u16_from_bits = crate::embed::FlatEmbed::U16FromBits; + let u16_from_bits_key = u16_from_bits.key::(); + + // define a function in the main module for the `u8_from_bits` embed + let u8_from_bits = crate::embed::FlatEmbed::U8FromBits; + let u8_from_bits_key = u8_from_bits.key::(); + + TypedProgram { + main: main_module_id.clone(), + modules: vec![( + main_module_id, + TypedModule { + functions: vec![ + (main_key, main_function), + (unpack_key.into(), TypedFunctionSymbol::Flat(unpack)), + ( + u32_from_bits_key.into(), + TypedFunctionSymbol::Flat(u32_from_bits), + ), + ( + u16_from_bits_key.into(), + TypedFunctionSymbol::Flat(u16_from_bits), + ), + ( + u8_from_bits_key.into(), + TypedFunctionSymbol::Flat(u8_from_bits), + ), + ( + u32_to_bits_key.into(), + TypedFunctionSymbol::Flat(u32_to_bits), + ), + ( + u16_to_bits_key.into(), + TypedFunctionSymbol::Flat(u16_to_bits), + ), + (u8_to_bits_key.into(), TypedFunctionSymbol::Flat(u8_to_bits)), + ] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + } + } +} diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index e2fc5a33..aad636ac 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -1,3 +1,4 @@ +use crate::embed::FlatEmbed; use crate::zir::folder::*; use crate::zir::*; use std::collections::HashMap; @@ -24,7 +25,7 @@ impl<'ast, T: Field> UintOptimizer<'ast, T> { } } -fn force_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> { +fn force_reduce(e: UExpression) -> UExpression { let metadata = e.metadata.unwrap(); let should_reduce = metadata.should_reduce.make_true(); @@ -38,7 +39,7 @@ fn force_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> } } -fn force_no_reduce<'ast, T: Field>(e: UExpression<'ast, T>) -> UExpression<'ast, T> { +fn force_no_reduce(e: UExpression) -> UExpression { let metadata = e.metadata.unwrap(); let should_reduce = metadata.should_reduce.make_false(); @@ -67,6 +68,42 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { BooleanExpression::UintEq(box left, box right) } + BooleanExpression::UintLt(box left, box right) => { + let left = self.fold_uint_expression(left); + let right = self.fold_uint_expression(right); + + let left = force_reduce(left); + let right = force_reduce(right); + + BooleanExpression::UintLt(box left, box right) + } + BooleanExpression::UintLe(box left, box right) => { + let left = self.fold_uint_expression(left); + let right = self.fold_uint_expression(right); + + let left = force_reduce(left); + let right = force_reduce(right); + + BooleanExpression::UintLe(box left, box right) + } + BooleanExpression::UintGt(box left, box right) => { + let left = self.fold_uint_expression(left); + let right = self.fold_uint_expression(right); + + let left = force_reduce(left); + let right = force_reduce(right); + + BooleanExpression::UintGt(box left, box right) + } + BooleanExpression::UintGe(box left, box right) => { + let left = self.fold_uint_expression(left); + let right = self.fold_uint_expression(right); + + let left = force_reduce(left); + let right = force_reduce(right); + + BooleanExpression::UintGe(box left, box right) + } e => fold_boolean_expression(self, e), } } @@ -94,7 +131,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { self.ids .get(&Variable::uint(id.clone(), range)) .cloned() - .expect(&format!("identifier should have been defined: {}", id)), + .unwrap_or_else(|| panic!("identifier should have been defined: {}", id)), ), Add(box left, box right) => { // reduce the two terms @@ -109,7 +146,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .map(|max| (false, false, max)) .unwrap_or_else(|| { range_max - .clone() .checked_add(&right_max) .map(|max| (true, false, max)) .unwrap_or_else(|| { @@ -167,7 +203,7 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { // if and only if `right_bitwidth` is `T::get_required_bits() - 1`, then `offset` is out of the interval // [0, 2**(max_bitwidth)[, therefore we need to reduce `right` left_max - .checked_add(&target_offset.clone()) + .checked_add(&target_offset) .map(|max| (false, true, max)) .unwrap_or_else(|| (true, true, range_max.clone() + target_offset)) } else { @@ -234,7 +270,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .map(|max| (false, false, max)) .unwrap_or_else(|| { range_max - .clone() .checked_mul(&right_max) .map(|max| (true, false, max)) .unwrap_or_else(|| { @@ -387,8 +422,8 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .collect(), )], ZirStatement::MultipleDefinition(lhs, rhs) => match rhs { - ZirExpressionList::FunctionCall(key, arguments, ty) => match key.clone().id { - "_U32_FROM_BITS" => { + ZirExpressionList::EmbedCall(embed, generics, arguments) => match embed { + FlatEmbed::U32FromBits => { assert_eq!(lhs.len(), 1); self.register( lhs[0].clone(), @@ -397,12 +432,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { should_reduce: ShouldReduce::False, }, ); + vec![ZirStatement::MultipleDefinition( lhs, - ZirExpressionList::FunctionCall(key, arguments, ty), + ZirExpressionList::EmbedCall(embed, generics, arguments), )] } - "_U16_FROM_BITS" => { + FlatEmbed::U16FromBits => { assert_eq!(lhs.len(), 1); self.register( lhs[0].clone(), @@ -413,10 +449,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ); vec![ZirStatement::MultipleDefinition( lhs, - ZirExpressionList::FunctionCall(key, arguments, ty), + ZirExpressionList::EmbedCall(embed, generics, arguments), )] } - "_U8_FROM_BITS" => { + FlatEmbed::U8FromBits => { assert_eq!(lhs.len(), 1); self.register( lhs[0].clone(), @@ -427,18 +463,18 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { ); vec![ZirStatement::MultipleDefinition( lhs, - ZirExpressionList::FunctionCall(key, arguments, ty), + ZirExpressionList::EmbedCall(embed, generics, arguments), )] } _ => vec![ZirStatement::MultipleDefinition( lhs, - ZirExpressionList::FunctionCall( - key, + ZirExpressionList::EmbedCall( + embed, + generics, arguments .into_iter() .map(|e| self.fold_expression(e)) .collect(), - ty, ), )], }, diff --git a/zokrates_core/src/static_analysis/unroll.rs b/zokrates_core/src/static_analysis/unroll.rs deleted file mode 100644 index 567c4f7b..00000000 --- a/zokrates_core/src/static_analysis/unroll.rs +++ /dev/null @@ -1,616 +0,0 @@ -//! Module containing SSA reduction, including for-loop unrolling -//! -//! @file unroll.rs -//! @author Thibaut Schaeffer -//! @date 2018 - -use crate::typed_absy::folder::*; -use crate::typed_absy::identifier::CoreIdentifier; -use crate::typed_absy::*; -use std::collections::HashMap; -use zokrates_field::Field; - -pub enum Output<'ast, T: Field> { - Complete(TypedProgram<'ast, T>), - Incomplete(TypedProgram<'ast, T>, usize), -} - -pub struct Unroller<'ast> { - // version index for any variable name - substitution: HashMap, usize>, - // whether all statements could be unrolled so far. Loops with variable bounds cannot. - complete: bool, - statement_count: usize, -} - -impl<'ast> Unroller<'ast> { - fn new() -> Self { - Unroller { - substitution: HashMap::new(), - complete: true, - statement_count: 0, - } - } - - fn issue_next_ssa_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> { - let res = match self.substitution.get(&v.id.id) { - Some(i) => Variable { - id: Identifier { - id: v.id.id.clone(), - version: i + 1, - stack: vec![], - }, - ..v - }, - None => Variable { ..v.clone() }, - }; - - self.substitution - .entry(v.id.id) - .and_modify(|e| *e += 1) - .or_insert(0); - res - } - - pub fn unroll(p: TypedProgram) -> Output { - let mut unroller = Unroller::new(); - let p = unroller.fold_program(p); - - match unroller.complete { - true => Output::Complete(p), - false => Output::Incomplete(p, unroller.statement_count), - } - } -} - -impl<'ast, T: Field> Folder<'ast, T> for Unroller<'ast> { - fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { - self.statement_count += 1; - match s { - TypedStatement::Declaration(_) => vec![], - TypedStatement::Definition(a, e) => { - let e = self.fold_expression(e); - - let a = match a { - TypedAssignee::Identifier(v) => { - TypedAssignee::Identifier(self.issue_next_ssa_variable(v)) - } - a => fold_assignee(self, a), - }; - - vec![TypedStatement::Definition(a, e)] - } - TypedStatement::MultipleDefinition(assignees, exprs) => { - let exprs = self.fold_expression_list(exprs); - let assignees = assignees - .into_iter() - .map(|a| match a { - TypedAssignee::Identifier(v) => { - TypedAssignee::Identifier(self.issue_next_ssa_variable(v)) - } - a => fold_assignee(self, a), - }) - .collect(); - - vec![TypedStatement::MultipleDefinition(assignees, exprs)] - } - TypedStatement::For(v, from, to, stats) => { - let from = self.fold_field_expression(from); - let to = self.fold_field_expression(to); - - match (from, to) { - (FieldElementExpression::Number(from), FieldElementExpression::Number(to)) => { - let mut values: Vec = vec![]; - let mut current = from; - while current < to { - values.push(current.clone()); - current = T::one() + ¤t; - } - - let res = values - .into_iter() - .map(|index| { - vec![ - vec![ - TypedStatement::Declaration(v.clone()), - TypedStatement::Definition( - TypedAssignee::Identifier(v.clone()), - FieldElementExpression::Number(index).into(), - ), - ], - stats.clone(), - ] - .into_iter() - .flat_map(|x| x) - }) - .flat_map(|x| x) - .flat_map(|x| self.fold_statement(x)) - .collect(); - - res - } - (from, to) => { - self.complete = false; - vec![TypedStatement::For(v, from, to, stats)] - } - } - } - s => fold_statement(self, s), - } - } - - fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { - self.substitution = HashMap::new(); - for arg in &f.arguments { - self.substitution.insert(arg.id.id.id.clone(), 0); - } - - fold_function(self, f) - } - - fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - Identifier { - version: self.substitution.get(&n.id).unwrap_or(&0).clone(), - ..n - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use zokrates_field::Bn128Field; - - #[cfg(test)] - mod statement { - use super::*; - use crate::typed_absy::types::{FunctionKey, Signature}; - - #[test] - fn for_loop() { - // for field i in 2..5 - // field foo = i - - // should be unrolled to - // i_0 = 2 - // foo_0 = i_0 - // i_1 = 3 - // foo_1 = i_1 - // i_2 = 4 - // foo_2 = i_2 - - let s = TypedStatement::For( - Variable::field_element("i"), - FieldElementExpression::Number(Bn128Field::from(2)), - FieldElementExpression::Number(Bn128Field::from(5)), - vec![ - TypedStatement::Declaration(Variable::field_element("foo")), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("foo")), - FieldElementExpression::Identifier("i".into()).into(), - ), - ], - ); - - let expected = vec![ - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("i").version(0), - )), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(0), - )), - FieldElementExpression::Identifier(Identifier::from("i").version(0)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("i").version(1), - )), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(1), - )), - FieldElementExpression::Identifier(Identifier::from("i").version(1)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("i").version(2), - )), - FieldElementExpression::Number(Bn128Field::from(4)).into(), - ), - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("foo").version(2), - )), - FieldElementExpression::Identifier(Identifier::from("i").version(2)).into(), - ), - ]; - - let mut u = Unroller::new(); - - assert_eq!(u.fold_statement(s), expected); - } - - #[test] - fn idempotence() { - // an already unrolled program should not be modified by unrolling again - - // b = [5] - // b[0] = 1 - // a = 5 - // a_1 = 6 - // a_2 = 7 - - // should be turned into - // b = [5] - // b[0] = 1 - // a = 5 - // a_1 = 6 - // a_2 = 7 - - let mut u = Unroller::new(); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::array( - Identifier::from("b").version(0), - Type::FieldElement, - 1, - )), - ArrayExpressionInner::Value(vec![FieldElementExpression::from(Bn128Field::from( - 5, - )) - .into()]) - .annotate(Type::FieldElement, 1) - .into(), - ); - assert_eq!(u.fold_statement(s.clone()), vec![s]); - - let s = TypedStatement::Definition( - TypedAssignee::Select( - box Variable::field_element(Identifier::from("b").version(0)).into(), - box FieldElementExpression::Number(Bn128Field::from(0)), - ), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ); - assert_eq!(u.fold_statement(s.clone()), vec![s]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(0), - )), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ); - assert_eq!(u.fold_statement(s.clone()), vec![s]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(1), - )), - FieldElementExpression::Number(Bn128Field::from(6)).into(), - ); - assert_eq!(u.fold_statement(s.clone()), vec![s]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(2), - )), - FieldElementExpression::Number(Bn128Field::from(7)).into(), - ); - assert_eq!(u.fold_statement(s.clone()), vec![s]); - } - - #[test] - fn definition() { - // field a - // a = 5 - // a = 6 - // a - - // should be turned into - // a_0 = 5 - // a_1 = 6 - // a_1 - - let mut u = Unroller::new(); - - let s: TypedStatement = - TypedStatement::Declaration(Variable::field_element("a")); - assert_eq!(u.fold_statement(s), vec![]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(0) - )), - FieldElementExpression::Number(Bn128Field::from(5)).into() - )] - ); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(6)).into(), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(1) - )), - FieldElementExpression::Number(Bn128Field::from(6)).into() - )] - ); - - let e: FieldElementExpression = - FieldElementExpression::Identifier("a".into()); - assert_eq!( - u.fold_field_expression(e), - FieldElementExpression::Identifier(Identifier::from("a").version(1)) - ); - } - - #[test] - fn incremental_definition() { - // field a - // a = 5 - // a = a + 1 - - // should be turned into - // a_0 = 5 - // a_1 = a_0 + 1 - - let mut u = Unroller::new(); - - let s: TypedStatement = - TypedStatement::Declaration(Variable::field_element("a")); - assert_eq!(u.fold_statement(s), vec![]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(0) - )), - FieldElementExpression::Number(Bn128Field::from(5)).into() - )] - ); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Add( - box FieldElementExpression::Identifier("a".into()), - box FieldElementExpression::Number(Bn128Field::from(1)), - ) - .into(), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(1) - )), - FieldElementExpression::Add( - box FieldElementExpression::Identifier(Identifier::from("a").version(0)), - box FieldElementExpression::Number(Bn128Field::from(1)) - ) - .into() - )] - ); - } - - #[test] - fn incremental_multiple_definition() { - use crate::typed_absy::types::Type; - - // field a - // a = 2 - // a = foo(a) - - // should be turned into - // a_0 = 2 - // a_1 = foo(a_0) - - let mut u = Unroller::new(); - - let s: TypedStatement = - TypedStatement::Declaration(Variable::field_element("a")); - assert_eq!(u.fold_statement(s), vec![]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element("a")), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_element( - Identifier::from("a").version(0) - )), - FieldElementExpression::Number(Bn128Field::from(2)).into() - )] - ); - - let s: TypedStatement = TypedStatement::MultipleDefinition( - vec![Variable::field_element("a").into()], - TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]), - ), - vec![FieldElementExpression::Identifier("a".into()).into()], - vec![Type::FieldElement], - ), - ); - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::MultipleDefinition( - vec![Variable::field_element(Identifier::from("a").version(1)).into()], - TypedExpressionList::FunctionCall( - FunctionKey::with_id("foo").signature( - Signature::new() - .inputs(vec![Type::FieldElement]) - .outputs(vec![Type::FieldElement]) - ), - vec![ - FieldElementExpression::Identifier(Identifier::from("a").version(0)) - .into() - ], - vec![Type::FieldElement], - ) - )] - ); - } - - #[test] - fn incremental_array_definition() { - // field[2] a = [1, 1] - // a[1] = 2 - - // should be turned into - // a_0 = [1, 1] - // a_0[1] = 2 - - let mut u = Unroller::new(); - - let s: TypedStatement = - TypedStatement::Declaration(Variable::field_array("a", 2)); - assert_eq!(u.fold_statement(s), vec![]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_array("a", 2)), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ); - - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::field_array( - Identifier::from("a").version(0), - 2 - )), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(1)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into() - ]) - .annotate(Type::FieldElement, 2) - .into() - )] - ); - - let s: TypedStatement = TypedStatement::Definition( - TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::field_array("a", 2)), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), - FieldElementExpression::Number(Bn128Field::from(2)).into(), - ); - - assert_eq!(u.fold_statement(s.clone()), vec![s]); - } - - #[test] - fn incremental_array_of_arrays_definition() { - // field[2][2] a = [[0, 1], [2, 3]] - // a[1] = [4, 5] - - // should be turned into - // a_0 = [[0, 1], [2, 3]] - // a_0 = [4, 5] - - let mut u = Unroller::new(); - - let array_of_array_ty = Type::array(Type::array(Type::FieldElement, 2), 2); - - let s: TypedStatement = TypedStatement::Declaration( - Variable::with_id_and_type("a", array_of_array_ty.clone()), - ); - assert_eq!(u.fold_statement(s), vec![]); - - let s = TypedStatement::Definition( - TypedAssignee::Identifier(Variable::with_id_and_type( - "a", - array_of_array_ty.clone(), - )), - ArrayExpressionInner::Value(vec![ - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ]) - .annotate(Type::array(Type::FieldElement, 2), 2) - .into(), - ); - - assert_eq!( - u.fold_statement(s), - vec![TypedStatement::Definition( - TypedAssignee::Identifier(Variable::with_id_and_type( - Identifier::from("a").version(0), - array_of_array_ty.clone(), - )), - ArrayExpressionInner::Value(vec![ - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(0)).into(), - FieldElementExpression::Number(Bn128Field::from(1)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(2)).into(), - FieldElementExpression::Number(Bn128Field::from(3)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ]) - .annotate(Type::array(Type::FieldElement, 2), 2) - .into(), - )] - ); - - let s: TypedStatement = TypedStatement::Definition( - TypedAssignee::Select( - box TypedAssignee::Identifier(Variable::with_id_and_type( - "a", - array_of_array_ty.clone(), - )), - box FieldElementExpression::Number(Bn128Field::from(1)), - ), - ArrayExpressionInner::Value(vec![ - FieldElementExpression::Number(Bn128Field::from(4)).into(), - FieldElementExpression::Number(Bn128Field::from(5)).into(), - ]) - .annotate(Type::FieldElement, 2) - .into(), - ); - - assert_eq!(u.fold_statement(s.clone()), vec![s]); - } - } -} diff --git a/zokrates_core/src/static_analysis/variable_read_remover.rs b/zokrates_core/src/static_analysis/variable_read_remover.rs index 563dcb4e..90c25f71 100644 --- a/zokrates_core/src/static_analysis/variable_read_remover.rs +++ b/zokrates_core/src/static_analysis/variable_read_remover.rs @@ -29,41 +29,50 @@ impl<'ast, T: Field> VariableReadRemover<'ast, T> { fn select + IfElse<'ast, T>>( &mut self, a: ArrayExpression<'ast, T>, - i: FieldElementExpression<'ast, T>, + i: UExpression<'ast, T>, ) -> U { - match i { - FieldElementExpression::Number(i) => U::select(a, FieldElementExpression::Number(i)), + match i.into_inner() { + UExpressionInner::Value(i) => { + U::select(a, UExpressionInner::Value(i).annotate(UBitwidth::B32)) + } i => { let size = match a.get_type().clone() { - Type::Array(array_ty) => array_ty.size, + Type::Array(array_ty) => match array_ty.size.into_inner() { + UExpressionInner::Value(size) => size as u32, + _ => unreachable!(), + }, _ => unreachable!(), }; self.statements.push(TypedStatement::Assertion( (0..size) .map(|index| { - BooleanExpression::FieldEq( - box i.clone(), - box FieldElementExpression::Number(index.into()).into(), + BooleanExpression::UintEq( + box i.clone().annotate(UBitwidth::B32), + box index.into(), ) }) .fold(None, |acc, e| match acc { Some(acc) => Some(BooleanExpression::Or(box acc, box e)), None => Some(e), }) - .unwrap() - .into(), + .unwrap(), )); (0..size) - .map(|i| U::select(a.clone(), FieldElementExpression::Number(i.into()))) + .map(|i| { + U::select( + a.clone(), + UExpressionInner::Value(i.into()).annotate(UBitwidth::B32), + ) + }) .enumerate() .rev() .fold(None, |acc, (index, res)| match acc { Some(acc) => Some(U::if_else( - BooleanExpression::FieldEq( - box i.clone(), - box FieldElementExpression::Number(index.into()), + BooleanExpression::UintEq( + box i.clone().annotate(UBitwidth::B32), + box (index as u32).into(), ), res, acc, @@ -99,21 +108,20 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableReadRemover<'ast, T> { fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { ArrayExpressionInner::Select(box a, box i) => { self.select::>(a, i).into_inner() } - e => fold_array_expression_inner(self, ty, size, e), + e => fold_array_expression_inner(self, ty, e), } } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { @@ -160,8 +168,8 @@ mod tests { let access: TypedStatement = TypedStatement::Definition( TypedAssignee::Identifier(Variable::field_element("b")), FieldElementExpression::Select( - box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2), - box FieldElementExpression::Identifier("i".into()), + box ArrayExpressionInner::Identifier("a".into()).annotate(Type::FieldElement, 2u32), + box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), ) .into(), ); @@ -169,35 +177,32 @@ mod tests { assert_eq!( VariableReadRemover::new().fold_statement(access), vec![ - TypedStatement::Assertion( - BooleanExpression::Or( - box BooleanExpression::FieldEq( - box FieldElementExpression::Identifier("i".into()), - box FieldElementExpression::Number(0.into()) - ), - box BooleanExpression::FieldEq( - box FieldElementExpression::Identifier("i".into()), - box FieldElementExpression::Number(1.into()) - ) + TypedStatement::Assertion(BooleanExpression::Or( + box BooleanExpression::UintEq( + box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32) + ), + box BooleanExpression::UintEq( + box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(1).annotate(UBitwidth::B32) ) - .into(), - ), + )), TypedStatement::Definition( TypedAssignee::Identifier(Variable::field_element("b")), FieldElementExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Identifier("i".into()), - box FieldElementExpression::Number(0.into()) + BooleanExpression::UintEq( + box UExpressionInner::Identifier("i".into()).annotate(UBitwidth::B32), + box UExpressionInner::Value(0).annotate(UBitwidth::B32) ), FieldElementExpression::Select( box ArrayExpressionInner::Identifier("a".into()) - .annotate(Type::FieldElement, 2), - box FieldElementExpression::Number(0.into()), + .annotate(Type::FieldElement, 2u32), + box 0u32.into(), ), FieldElementExpression::Select( box ArrayExpressionInner::Identifier("a".into()) - .annotate(Type::FieldElement, 2), - box FieldElementExpression::Number(1.into()), + .annotate(Type::FieldElement, 2u32), + box 1u32.into(), ) ) .into() diff --git a/zokrates_core/src/static_analysis/variable_write_remover.rs b/zokrates_core/src/static_analysis/variable_write_remover.rs index 24aa8f57..fba777bf 100644 --- a/zokrates_core/src/static_analysis/variable_write_remover.rs +++ b/zokrates_core/src/static_analysis/variable_write_remover.rs @@ -37,33 +37,31 @@ impl<'ast> VariableWriteRemover { let inner_ty = base.inner_type(); let size = base.size(); + let size = match size.as_inner() { + UExpressionInner::Value(v) => *v as u32, + _ => unreachable!(), + }; + let head = indices.remove(0); let tail = indices; match head { Access::Select(head) => { statements.insert(TypedStatement::Assertion( - BooleanExpression::Lt( - box head.clone(), - box FieldElementExpression::Number(T::from(size)), - ) - .into(), + BooleanExpression::UintLt(box head.clone(), box size.into()), )); ArrayExpressionInner::Value( (0..size) .map(|i| match inner_ty { + Type::Int => unreachable!(), Type::Array(..) => ArrayExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), + BooleanExpression::UintEq( + box i.into(), box head.clone(), ), match Self::choose_many( - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), + ArrayExpression::select(base.clone(), i).into(), tail.clone(), new_expression.clone(), statements, @@ -74,23 +72,16 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - ArrayExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), + ArrayExpression::select(base.clone(), i), ) .into(), Type::Struct(..) => StructExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), + BooleanExpression::UintEq( + box i.into(), box head.clone(), ), match Self::choose_many( - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), + StructExpression::select(base.clone(), i).into(), tail.clone(), new_expression.clone(), statements, @@ -101,23 +92,17 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - StructExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), + StructExpression::select(base.clone(), i), ) .into(), Type::FieldElement => FieldElementExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), + BooleanExpression::UintEq( + box i.into(), box head.clone(), ), match Self::choose_many( - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), + FieldElementExpression::select(base.clone(), i) + .into(), tail.clone(), new_expression.clone(), statements, @@ -128,23 +113,16 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - FieldElementExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), + FieldElementExpression::select(base.clone(), i), ) .into(), Type::Boolean => BooleanExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), + BooleanExpression::UintEq( + box i.into(), box head.clone(), ), match Self::choose_many( - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), + BooleanExpression::select(base.clone(), i).into(), tail.clone(), new_expression.clone(), statements, @@ -155,23 +133,16 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - BooleanExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), + BooleanExpression::select(base.clone(), i), ) .into(), Type::Uint(..) => UExpression::if_else( - BooleanExpression::FieldEq( - box FieldElementExpression::Number(T::from(i)), + BooleanExpression::UintEq( + box i.into(), box head.clone(), ), match Self::choose_many( - UExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ) - .into(), + UExpression::select(base.clone(), i).into(), tail.clone(), new_expression.clone(), statements, @@ -182,14 +153,12 @@ impl<'ast> VariableWriteRemover { e.get_type() ), }, - UExpression::select( - base.clone(), - FieldElementExpression::Number(T::from(i)), - ), + UExpression::select(base.clone(), i), ) .into(), }) - .collect(), + .collect::>() + .into(), ) .annotate(inner_ty.clone(), size) .into() @@ -212,6 +181,7 @@ impl<'ast> VariableWriteRemover { .clone() .into_iter() .map(|member| match *member.ty { + Type::Int => unreachable!(), Type::FieldElement => { if member.id == head { Self::choose_many( @@ -225,11 +195,8 @@ impl<'ast> VariableWriteRemover { statements, ) } else { - FieldElementExpression::member( - base.clone(), - member.id.clone(), - ) - .into() + FieldElementExpression::member(base.clone(), member.id) + .into() } } Type::Uint(..) => { @@ -242,8 +209,7 @@ impl<'ast> VariableWriteRemover { statements, ) } else { - UExpression::member(base.clone(), member.id.clone()) - .into() + UExpression::member(base.clone(), member.id).into() } } Type::Boolean => { @@ -259,11 +225,8 @@ impl<'ast> VariableWriteRemover { statements, ) } else { - BooleanExpression::member( - base.clone(), - member.id.clone(), - ) - .into() + BooleanExpression::member(base.clone(), member.id) + .into() } } Type::Array(..) => { @@ -276,8 +239,7 @@ impl<'ast> VariableWriteRemover { statements, ) } else { - ArrayExpression::member(base.clone(), member.id.clone()) - .into() + ArrayExpression::member(base.clone(), member.id).into() } } Type::Struct(..) => { @@ -293,11 +255,7 @@ impl<'ast> VariableWriteRemover { statements, ) } else { - StructExpression::member( - base.clone(), - member.id.clone(), - ) - .into() + StructExpression::member(base.clone(), member.id).into() } } }) @@ -316,12 +274,12 @@ impl<'ast> VariableWriteRemover { #[derive(Clone, Debug)] enum Access<'ast, T: Field> { - Select(FieldElementExpression<'ast, T>), + Select(UExpression<'ast, T>), Member(MemberId), } /// Turn an assignee into its representation as a base variable and a list accesses /// a[2][3][4] -> (a, [2, 3, 4]) -fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec>) { +fn linear(a: TypedAssignee) -> (Variable, Vec>) { match a { TypedAssignee::Identifier(v) => (v, vec![]), TypedAssignee::Select(box array, box index) => { @@ -337,11 +295,11 @@ fn linear<'ast, T: Field>(a: TypedAssignee<'ast, T>) -> (Variable, Vec(assignee: &TypedAssignee<'ast, T>) -> bool { +fn is_constant(assignee: &TypedAssignee) -> bool { match assignee { TypedAssignee::Identifier(_) => true, - TypedAssignee::Select(box assignee, box index) => match index { - FieldElementExpression::Number(_) => is_constant(assignee), + TypedAssignee::Select(box assignee, box index) => match index.as_inner() { + UExpressionInner::Value(_) => is_constant(assignee), _ => false, }, TypedAssignee::Member(box assignee, _) => is_constant(assignee), @@ -362,24 +320,21 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover { let (variable, indices) = linear(assignee); let base = match variable.get_type() { + Type::Int => unreachable!(), Type::FieldElement => { - FieldElementExpression::Identifier(variable.id.clone().into()).into() - } - Type::Boolean => { - BooleanExpression::Identifier(variable.id.clone().into()).into() - } - Type::Uint(bitwidth) => { - UExpressionInner::Identifier(variable.id.clone().into()) - .annotate(bitwidth) - .into() + FieldElementExpression::Identifier(variable.id.clone()).into() } + Type::Boolean => BooleanExpression::Identifier(variable.id.clone()).into(), + Type::Uint(bitwidth) => UExpressionInner::Identifier(variable.id.clone()) + .annotate(bitwidth) + .into(), Type::Array(array_type) => { - ArrayExpressionInner::Identifier(variable.id.clone().into()) + ArrayExpressionInner::Identifier(variable.id.clone()) .annotate(*array_type.ty, array_type.size) .into() } Type::Struct(members) => { - StructExpressionInner::Identifier(variable.id.clone().into()) + StructExpressionInner::Identifier(variable.id.clone()) .annotate(members) .into() } @@ -390,7 +345,7 @@ impl<'ast, T: Field> Folder<'ast, T> for VariableWriteRemover { let indices = indices .into_iter() .map(|a| match a { - Access::Select(i) => Access::Select(self.fold_field_expression(i)), + Access::Select(i) => Access::Select(self.fold_uint_expression(i)), a => a, }) .collect(); diff --git a/zokrates_core/src/typed_absy/abi.rs b/zokrates_core/src/typed_absy/abi.rs index f8ff7a87..355f7225 100644 --- a/zokrates_core/src/typed_absy/abi.rs +++ b/zokrates_core/src/typed_absy/abi.rs @@ -1,15 +1,15 @@ -use crate::typed_absy::{Signature, Type}; +use crate::typed_absy::types::{ConcreteSignature, ConcreteType}; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct AbiInput { pub name: String, pub public: bool, #[serde(flatten)] - pub ty: Type, + pub ty: ConcreteType, } -pub type AbiOutput = Type; +pub type AbiOutput = ConcreteType; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct Abi { @@ -18,8 +18,9 @@ pub struct Abi { } impl Abi { - pub fn signature(&self) -> Signature { - Signature { + pub fn signature(&self) -> ConcreteSignature { + ConcreteSignature { + generics: vec![], inputs: self.inputs.iter().map(|i| i.ty.clone()).collect(), outputs: self.outputs.clone(), } @@ -29,10 +30,12 @@ impl Abi { #[cfg(test)] mod tests { use super::*; - use crate::typed_absy::types::{ArrayType, FunctionKey, StructMember, StructType}; + use crate::typed_absy::types::{ + ConcreteArrayType, ConcreteFunctionKey, ConcreteStructMember, ConcreteStructType, UBitwidth, + }; use crate::typed_absy::{ - Parameter, Type, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, UBitwidth, - Variable, + parameter::DeclarationParameter, variable::DeclarationVariable, ConcreteType, + TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, }; use std::collections::HashMap; use zokrates_field::Bn128Field; @@ -41,22 +44,25 @@ mod tests { fn generate_abi_from_typed_ast() { let mut functions = HashMap::new(); functions.insert( - FunctionKey::with_id("main"), + ConcreteFunctionKey::with_location("main", "main").into(), TypedFunctionSymbol::Here(TypedFunction { arguments: vec![ - Parameter { - id: Variable::field_element("a"), + DeclarationParameter { + id: DeclarationVariable::field_element("a"), private: true, - }, - Parameter { - id: Variable::boolean("b"), + } + .into(), + DeclarationParameter { + id: DeclarationVariable::boolean("b"), private: false, - }, + } + .into(), ], statements: vec![], - signature: Signature::new() - .inputs(vec![Type::FieldElement, Type::Boolean]) - .outputs(vec![Type::FieldElement]), + signature: ConcreteSignature::new() + .inputs(vec![ConcreteType::FieldElement, ConcreteType::Boolean]) + .outputs(vec![ConcreteType::FieldElement]) + .into(), }), ); @@ -74,15 +80,15 @@ mod tests { AbiInput { name: String::from("a"), public: false, - ty: Type::FieldElement, + ty: ConcreteType::FieldElement, }, AbiInput { name: String::from("b"), public: true, - ty: Type::Boolean, + ty: ConcreteType::Boolean, }, ], - outputs: vec![Type::FieldElement], + outputs: vec![ConcreteType::FieldElement], }; assert_eq!(expected_abi, abi); @@ -101,6 +107,19 @@ mod tests { assert_eq!(de_abi, abi); } + #[test] + #[should_panic] + fn serialize_integer() { + // serializing the Int type should panic as it is not allowed in signatures + + let abi: Abi = Abi { + inputs: vec![], + outputs: vec![ConcreteType::Int], + }; + + let _ = serde_json::to_string_pretty(&abi).unwrap(); + } + #[test] fn serialize_field() { let abi: Abi = Abi { @@ -108,15 +127,15 @@ mod tests { AbiInput { name: String::from("a"), public: true, - ty: Type::FieldElement, + ty: ConcreteType::FieldElement, }, AbiInput { name: String::from("b"), public: true, - ty: Type::FieldElement, + ty: ConcreteType::FieldElement, }, ], - outputs: vec![Type::FieldElement], + outputs: vec![ConcreteType::FieldElement], }; let json = serde_json::to_string_pretty(&abi).unwrap(); @@ -154,17 +173,17 @@ mod tests { AbiInput { name: String::from("a"), public: true, - ty: Type::Uint(UBitwidth::B8), + ty: ConcreteType::Uint(UBitwidth::B8), }, AbiInput { name: String::from("b"), public: true, - ty: Type::Uint(UBitwidth::B16), + ty: ConcreteType::Uint(UBitwidth::B16), }, AbiInput { name: String::from("c"), public: true, - ty: Type::Uint(UBitwidth::B32), + ty: ConcreteType::Uint(UBitwidth::B32), }, ], outputs: vec![], @@ -205,21 +224,21 @@ mod tests { inputs: vec![AbiInput { name: String::from("foo"), public: true, - ty: Type::Struct(StructType::new( + ty: ConcreteType::Struct(ConcreteStructType::new( "".into(), "Foo".into(), vec![ - StructMember::new(String::from("a"), Type::FieldElement), - StructMember::new(String::from("b"), Type::Boolean), + ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement), + ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean), ], )), }], - outputs: vec![Type::Struct(StructType::new( + outputs: vec![ConcreteType::Struct(ConcreteStructType::new( "".into(), "Foo".into(), vec![ - StructMember::new(String::from("a"), Type::FieldElement), - StructMember::new(String::from("b"), Type::Boolean), + ConcreteStructMember::new(String::from("a"), ConcreteType::FieldElement), + ConcreteStructMember::new(String::from("b"), ConcreteType::Boolean), ], ))], }; @@ -279,17 +298,23 @@ mod tests { inputs: vec![AbiInput { name: String::from("foo"), public: true, - ty: Type::Struct(StructType::new( + ty: ConcreteType::Struct(ConcreteStructType::new( "".into(), "Foo".into(), - vec![StructMember::new( + vec![ConcreteStructMember::new( String::from("bar"), - Type::Struct(StructType::new( + ConcreteType::Struct(ConcreteStructType::new( "".into(), "Bar".into(), vec![ - StructMember::new(String::from("a"), Type::FieldElement), - StructMember::new(String::from("b"), Type::FieldElement), + ConcreteStructMember::new( + String::from("a"), + ConcreteType::FieldElement, + ), + ConcreteStructMember::new( + String::from("b"), + ConcreteType::FieldElement, + ), ], )), )], @@ -345,19 +370,22 @@ mod tests { inputs: vec![AbiInput { name: String::from("a"), public: false, - ty: Type::Array(ArrayType::new( - Type::Struct(StructType::new( + ty: ConcreteType::Array(ConcreteArrayType::new( + ConcreteType::Struct(ConcreteStructType::new( "".into(), "Foo".into(), vec![ - StructMember::new(String::from("b"), Type::FieldElement), - StructMember::new(String::from("c"), Type::Boolean), + ConcreteStructMember::new( + String::from("b"), + ConcreteType::FieldElement, + ), + ConcreteStructMember::new(String::from("c"), ConcreteType::Boolean), ], )), 2, )), }], - outputs: vec![Type::Boolean], + outputs: vec![ConcreteType::Boolean], }; let json = serde_json::to_string_pretty(&abi).unwrap(); @@ -406,12 +434,12 @@ mod tests { inputs: vec![AbiInput { name: String::from("a"), public: false, - ty: Type::Array(ArrayType::new( - Type::Array(ArrayType::new(Type::FieldElement, 2)), + ty: ConcreteType::Array(ConcreteArrayType::new( + ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2)), 2, )), }], - outputs: vec![Type::FieldElement], + outputs: vec![ConcreteType::FieldElement], }; let json = serde_json::to_string_pretty(&abi).unwrap(); diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index cfc37a5e..5aaacc15 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -1,5 +1,6 @@ // Generic walk through a typed AST. Not mutating in place +use crate::typed_absy::types::{ArrayType, StructMember, StructType}; use crate::typed_absy::*; use zokrates_field::Field; @@ -23,9 +24,9 @@ pub trait Folder<'ast, T: Field>: Sized { fold_function(self, f) } - fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> { - Parameter { - id: self.fold_variable(p.id), + fn fold_parameter(&mut self, p: DeclarationParameter<'ast>) -> DeclarationParameter<'ast> { + DeclarationParameter { + id: self.fold_declaration_variable(p.id), ..p } } @@ -34,13 +35,58 @@ pub trait Folder<'ast, T: Field>: Sized { n } - fn fold_variable(&mut self, v: Variable<'ast>) -> Variable<'ast> { + fn fold_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { Variable { id: self.fold_name(v.id), - ..v + _type: self.fold_type(v._type), } } + fn fold_declaration_variable( + &mut self, + v: DeclarationVariable<'ast>, + ) -> DeclarationVariable<'ast> { + DeclarationVariable { + id: self.fold_name(v.id), + _type: self.fold_declaration_type(v._type), + } + } + + fn fold_type(&mut self, t: Type<'ast, T>) -> Type<'ast, T> { + use self::GType::*; + + match t { + Array(array_type) => Array(self.fold_array_type(array_type)), + Struct(struct_type) => Struct(self.fold_struct_type(struct_type)), + t => t, + } + } + + fn fold_array_type(&mut self, t: ArrayType<'ast, T>) -> ArrayType<'ast, T> { + ArrayType { + ty: box self.fold_type(*t.ty), + size: self.fold_uint_expression(t.size), + } + } + + fn fold_struct_type(&mut self, t: StructType<'ast, T>) -> StructType<'ast, T> { + StructType { + members: t + .members + .into_iter() + .map(|m| StructMember { + ty: box self.fold_type(*m.ty), + ..m + }) + .collect(), + ..t + } + } + + fn fold_declaration_type(&mut self, t: DeclarationType<'ast>) -> DeclarationType<'ast> { + t + } + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { fold_assignee(self, a) } @@ -49,6 +95,26 @@ pub trait Folder<'ast, T: Field>: Sized { fold_statement(self, s) } + fn fold_expression_or_spread( + &mut self, + e: TypedExpressionOrSpread<'ast, T>, + ) -> TypedExpressionOrSpread<'ast, T> { + match e { + TypedExpressionOrSpread::Expression(e) => { + TypedExpressionOrSpread::Expression(self.fold_expression(e)) + } + TypedExpressionOrSpread::Spread(s) => { + TypedExpressionOrSpread::Spread(self.fold_spread(s)) + } + } + } + + fn fold_spread(&mut self, s: TypedSpread<'ast, T>) -> TypedSpread<'ast, T> { + TypedSpread { + array: self.fold_array_expression(s.array), + } + } + fn fold_expression(&mut self, e: TypedExpression<'ast, T>) -> TypedExpression<'ast, T> { match e { TypedExpression::FieldElement(e) => self.fold_field_expression(e).into(), @@ -56,6 +122,7 @@ pub trait Folder<'ast, T: Field>: Sized { TypedExpression::Uint(e) => self.fold_uint_expression(e).into(), TypedExpression::Array(e) => self.fold_array_expression(e).into(), TypedExpression::Struct(e) => self.fold_struct_expression(e).into(), + TypedExpression::Int(e) => self.fold_int_expression(e).into(), } } @@ -74,18 +141,11 @@ pub trait Folder<'ast, T: Field>: Sized { &mut self, es: TypedExpressionList<'ast, T>, ) -> TypedExpressionList<'ast, T> { - match es { - TypedExpressionList::FunctionCall(id, arguments, types) => { - TypedExpressionList::FunctionCall( - id, - arguments - .into_iter() - .map(|a| self.fold_expression(a)) - .collect(), - types, - ) - } - } + fold_expression_list(self, es) + } + + fn fold_int_expression(&mut self, e: IntExpression<'ast, T>) -> IntExpression<'ast, T> { + fold_int_expression(self, e) } fn fold_field_expression( @@ -114,15 +174,14 @@ pub trait Folder<'ast, T: Field>: Sized { fn fold_array_expression_inner( &mut self, - ty: &Type, - size: usize, + ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { - fold_array_expression_inner(self, ty, size, e) + fold_array_expression_inner(self, ty, e) } fn fold_struct_expression_inner( &mut self, - ty: &StructType, + ty: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { fold_struct_expression_inner(self, ty, e) @@ -139,7 +198,6 @@ pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( .into_iter() .map(|(key, fun)| (key, f.fold_function_symbol(fun))) .collect(), - ..p } } @@ -161,8 +219,8 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)), TypedStatement::For(v, from, to, statements) => TypedStatement::For( f.fold_variable(v), - from, - to, + f.fold_uint_expression(from), + f.fold_uint_expression(to), statements .into_iter() .flat_map(|s| f.fold_statement(s)) @@ -172,24 +230,31 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( assignees.into_iter().map(|a| f.fold_assignee(a)).collect(), f.fold_expression_list(elist), ), + s => s, }; vec![res] } pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - _: &Type, - _: usize, + _: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)), - ArrayExpressionInner::Value(exprs) => { - ArrayExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()) - } - ArrayExpressionInner::FunctionCall(id, exps) => { + ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value( + exprs + .into_iter() + .map(|e| f.fold_expression_or_spread(e)) + .collect(), + ), + ArrayExpressionInner::FunctionCall(id, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(); let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); - ArrayExpressionInner::FunctionCall(id, exps) + ArrayExpressionInner::FunctionCall(id, generics, exps) } ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { ArrayExpressionInner::IfElse( @@ -204,15 +269,26 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } ArrayExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); ArrayExpressionInner::Select(box array, box index) } + ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = f.fold_array_expression(array); + let from = f.fold_uint_expression(from); + let to = f.fold_uint_expression(to); + ArrayExpressionInner::Slice(box array, box from, box to) + } + ArrayExpressionInner::Repeat(box e, box count) => { + let e = f.fold_expression(e); + let count = f.fold_uint_expression(count); + ArrayExpressionInner::Repeat(box e, box count) + } } } pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, - _: &StructType, + _: &StructType<'ast, T>, e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { @@ -220,9 +296,13 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( StructExpressionInner::Value(exprs) => { StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()) } - StructExpressionInner::FunctionCall(id, exps) => { + StructExpressionInner::FunctionCall(id, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(); let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); - StructExpressionInner::FunctionCall(id, exps) + StructExpressionInner::FunctionCall(id, generics, exps) } StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { StructExpressionInner::IfElse( @@ -237,7 +317,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } StructExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); StructExpressionInner::Select(box array, box index) } } @@ -274,7 +354,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } FieldElementExpression::Pow(box e1, box e2) => { let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); + let e2 = f.fold_uint_expression(e2); FieldElementExpression::Pow(box e1, box e2) } FieldElementExpression::Neg(box e) => { @@ -293,9 +373,13 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( let alt = f.fold_field_expression(alt); FieldElementExpression::IfElse(box cond, box cons, box alt) } - FieldElementExpression::FunctionCall(key, exps) => { + FieldElementExpression::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(); let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); - FieldElementExpression::FunctionCall(key, exps) + FieldElementExpression::FunctionCall(key, generics, exps) } FieldElementExpression::Member(box s, id) => { let s = f.fold_struct_expression(s); @@ -303,12 +387,19 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } FieldElementExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); FieldElementExpression::Select(box array, box index) } } } +pub fn fold_int_expression<'ast, T: Field, F: Folder<'ast, T>>( + _: &mut F, + _: IntExpression<'ast, T>, +) -> IntExpression<'ast, T> { + unreachable!() +} + pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, @@ -341,25 +432,45 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( let e2 = f.fold_uint_expression(e2); BooleanExpression::UintEq(box e1, box e2) } - BooleanExpression::Lt(box e1, box e2) => { + BooleanExpression::FieldLt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Lt(box e1, box e2) + BooleanExpression::FieldLt(box e1, box e2) } - BooleanExpression::Le(box e1, box e2) => { + BooleanExpression::FieldLe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Le(box e1, box e2) + BooleanExpression::FieldLe(box e1, box e2) } - BooleanExpression::Gt(box e1, box e2) => { + BooleanExpression::FieldGt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Gt(box e1, box e2) + BooleanExpression::FieldGt(box e1, box e2) } - BooleanExpression::Ge(box e1, box e2) => { + BooleanExpression::FieldGe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Ge(box e1, box e2) + BooleanExpression::FieldGe(box e1, box e2) + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintLt(box e1, box e2) + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintLe(box e1, box e2) + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintGt(box e1, box e2) + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintGe(box e1, box e2) } BooleanExpression::Or(box e1, box e2) => { let e1 = f.fold_boolean_expression(e1); @@ -375,9 +486,13 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( let e = f.fold_boolean_expression(e); BooleanExpression::Not(box e) } - BooleanExpression::FunctionCall(key, exps) => { + BooleanExpression::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(); let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); - BooleanExpression::FunctionCall(key, exps) + BooleanExpression::FunctionCall(key, generics, exps) } BooleanExpression::IfElse(box cond, box cons, box alt) => { let cond = f.fold_boolean_expression(cond); @@ -391,7 +506,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( } BooleanExpression::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); BooleanExpression::Select(box array, box index) } } @@ -427,6 +542,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( UExpressionInner::Sub(box left, box right) } + UExpressionInner::FloorSub(box left, box right) => { + let left = f.fold_uint_expression(left); + let right = f.fold_uint_expression(right); + + UExpressionInner::FloorSub(box left, box right) + } UExpressionInner::Mult(box left, box right) => { let left = f.fold_uint_expression(left); let right = f.fold_uint_expression(right); @@ -490,13 +611,17 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( UExpressionInner::Pos(box e) } - UExpressionInner::FunctionCall(key, exps) => { + UExpressionInner::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(); let exps = exps.into_iter().map(|e| f.fold_expression(e)).collect(); - UExpressionInner::FunctionCall(key, exps) + UExpressionInner::FunctionCall(key, generics, exps) } UExpressionInner::Select(box array, box index) => { let array = f.fold_array_expression(array); - let index = f.fold_field_expression(index); + let index = f.fold_uint_expression(index); UExpressionInner::Select(box array, box index) } UExpressionInner::IfElse(box cond, box cons, box alt) => { @@ -535,9 +660,44 @@ pub fn fold_array_expression<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, e: ArrayExpression<'ast, T>, ) -> ArrayExpression<'ast, T> { + let ty = f.fold_array_type(*e.ty); + ArrayExpression { - inner: f.fold_array_expression_inner(&e.ty, e.size, e.inner), - ..e + inner: f.fold_array_expression_inner(&ty, e.inner), + ty: box ty, + } +} + +pub fn fold_expression_list<'ast, T: Field, F: Folder<'ast, T>>( + f: &mut F, + es: TypedExpressionList<'ast, T>, +) -> TypedExpressionList<'ast, T> { + match es { + TypedExpressionList::FunctionCall(id, generics, arguments, types) => { + TypedExpressionList::FunctionCall( + id, + generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g))) + .collect(), + arguments + .into_iter() + .map(|a| f.fold_expression(a)) + .collect(), + types.into_iter().map(|t| f.fold_type(t)).collect(), + ) + } + TypedExpressionList::EmbedCall(embed, generics, arguments, types) => { + TypedExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|a| f.fold_expression(a)) + .collect(), + types.into_iter().map(|t| f.fold_type(t)).collect(), + ) + } } } @@ -568,7 +728,7 @@ pub fn fold_assignee<'ast, T: Field, F: Folder<'ast, T>>( match a { TypedAssignee::Identifier(v) => TypedAssignee::Identifier(f.fold_variable(v)), TypedAssignee::Select(box a, box index) => { - TypedAssignee::Select(box f.fold_assignee(a), box f.fold_field_expression(index)) + TypedAssignee::Select(box f.fold_assignee(a), box f.fold_uint_expression(index)) } TypedAssignee::Member(box s, m) => TypedAssignee::Member(box f.fold_assignee(s), m), } diff --git a/zokrates_core/src/typed_absy/identifier.rs b/zokrates_core/src/typed_absy/identifier.rs index eef72c20..145950e0 100644 --- a/zokrates_core/src/typed_absy/identifier.rs +++ b/zokrates_core/src/typed_absy/identifier.rs @@ -1,12 +1,11 @@ -use crate::typed_absy::types::FunctionKeyHash; -use crate::typed_absy::TypedModuleId; +use std::convert::TryInto; use std::fmt; #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub enum CoreIdentifier<'ast> { Source(&'ast str), Internal(&'static str, usize), - Call(FunctionKeyHash, usize), + Call(usize), } impl<'ast> fmt::Display for CoreIdentifier<'ast> { @@ -14,11 +13,17 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { match self { CoreIdentifier::Source(s) => write!(f, "{}", s), CoreIdentifier::Internal(s, i) => write!(f, "#INTERNAL#_{}_{}", s, i), - CoreIdentifier::Call(k, i) => write!(f, "{:x}_{}", k, i), + CoreIdentifier::Call(i) => write!(f, "#CALL_RETURN_AT_INDEX_{}", i), } } } +impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { + fn from(s: &str) -> CoreIdentifier { + CoreIdentifier::Source(s) + } +} + /// A identifier for a variable #[derive(Debug, PartialEq, Clone, Hash, Eq)] pub struct Identifier<'ast> { @@ -26,31 +31,25 @@ pub struct Identifier<'ast> { pub id: CoreIdentifier<'ast>, /// the version of the variable, used after SSA transformation pub version: usize, - /// the call stack of the variable, used when inlining - pub stack: Vec<(TypedModuleId, FunctionKeyHash, usize)>, +} + +impl<'ast> TryInto<&'ast str> for Identifier<'ast> { + type Error = (); + + fn try_into(self) -> Result<&'ast str, Self::Error> { + match self.id { + CoreIdentifier::Source(i) => Ok(i), + _ => Err(()), + } + } } impl<'ast> fmt::Display for Identifier<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.stack.len() == 0 && self.version == 0 { + if self.version == 0 { write!(f, "{}", self.id) } else { - write!( - f, - "{}_{}_{}", - self.stack - .iter() - .map(|(name, key_hash, count)| format!( - "{}_{}_{}", - name.display(), - key_hash, - count - )) - .collect::>() - .join("_"), - self.id, - self.version - ) + write!(f, "{}_{}", self.id, self.version) } } } @@ -63,23 +62,13 @@ impl<'ast> From<&'ast str> for Identifier<'ast> { impl<'ast> From> for Identifier<'ast> { fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> { - Identifier { - id, - version: 0, - stack: vec![], - } + Identifier { id, version: 0 } } } -#[cfg(test)] impl<'ast> Identifier<'ast> { pub fn version(mut self, version: usize) -> Self { self.version = version; self } - - pub fn stack(mut self, stack: Vec<(TypedModuleId, FunctionKeyHash, usize)>) -> Self { - self.stack = stack; - self - } } diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs new file mode 100644 index 00000000..d15d549a --- /dev/null +++ b/zokrates_core/src/typed_absy/integer.rs @@ -0,0 +1,671 @@ +use crate::typed_absy::types::{ArrayType, Type}; +use crate::typed_absy::UBitwidth; +use crate::typed_absy::{ + ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse, + Select, StructExpression, Typed, TypedExpression, TypedExpressionOrSpread, TypedSpread, + UExpression, UExpressionInner, +}; +use num_bigint::BigUint; +use std::convert::TryFrom; +use std::fmt; +use std::ops::{Add, Div, Mul, Not, Rem, Sub}; +use zokrates_field::Field; + +type TypedExpressionPair<'ast, T> = (TypedExpression<'ast, T>, TypedExpression<'ast, T>); + +impl<'ast, T: Field> TypedExpressionOrSpread<'ast, T> { + pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result)> { + match e { + TypedExpressionOrSpread::Expression(e) => TypedExpression::align_to_type(e, ty) + .map(|e| e.into()) + .map_err(|(e, t)| (e.into(), t)), + TypedExpressionOrSpread::Spread(s) => { + ArrayExpression::try_from_int(s.array, ty.clone()) + .map(|e| TypedExpressionOrSpread::Spread(TypedSpread { array: e })) + .map_err(|e| (e.into(), ty)) + } + } + } +} + +impl<'ast, T: Field> TypedExpression<'ast, T> { + // return two TypedExpression, replacing IntExpression by FieldElement or Uint to try to align the two types if possible. + // Post condition is that (lhs, rhs) cannot be made equal by further removing IntExpressions + pub fn align_without_integers( + lhs: Self, + rhs: Self, + ) -> Result, TypedExpressionPair<'ast, T>> { + use self::TypedExpression::*; + + match (lhs, rhs) { + (Int(lhs), FieldElement(rhs)) => Ok(( + FieldElementExpression::try_from_int(lhs) + .map_err(|lhs| (lhs.into(), rhs.clone().into()))? + .into(), + FieldElement(rhs), + )), + (FieldElement(lhs), Int(rhs)) => Ok(( + FieldElement(lhs.clone()), + FieldElementExpression::try_from_int(rhs) + .map_err(|rhs| (lhs.into(), rhs.into()))? + .into(), + )), + (Int(lhs), Uint(rhs)) => Ok(( + UExpression::try_from_int(lhs, rhs.bitwidth()) + .map_err(|lhs| (lhs.into(), rhs.clone().into()))? + .into(), + Uint(rhs), + )), + (Uint(lhs), Int(rhs)) => { + let bitwidth = lhs.bitwidth(); + Ok(( + Uint(lhs.clone()), + UExpression::try_from_int(rhs, bitwidth) + .map_err(|rhs| (lhs.into(), rhs.into()))? + .into(), + )) + } + (Array(lhs), Array(rhs)) => { + fn get_common_type<'a, T: Field>( + t: Type<'a, T>, + u: Type<'a, T>, + ) -> Result, ()> { + match (t, u) { + (Type::Int, Type::Int) => Err(()), + (Type::Int, u) => Ok(u), + (t, Type::Int) => Ok(t), + (Type::Array(t), Type::Array(u)) => Ok(Type::Array(ArrayType::new( + get_common_type(*t.ty, *u.ty)?, + t.size, + ))), + (t, _) => Ok(t), + } + } + + let common_type = + get_common_type(lhs.inner_type().clone(), rhs.inner_type().clone()) + .map_err(|_| (lhs.clone().into(), rhs.clone().into()))?; + + Ok(( + ArrayExpression::try_from_int(lhs.clone(), common_type.clone()) + .map_err(|lhs| (lhs.clone(), rhs.clone().into()))? + .into(), + ArrayExpression::try_from_int(rhs, common_type) + .map_err(|rhs| (lhs.clone().into(), rhs.clone()))? + .into(), + )) + } + (Struct(lhs), Struct(rhs)) => { + if lhs.get_type() == rhs.get_type() { + Ok((Struct(lhs), Struct(rhs))) + } else { + Err((Struct(lhs), Struct(rhs))) + } + } + (Uint(lhs), Uint(rhs)) => Ok((lhs.into(), rhs.into())), + (Boolean(lhs), Boolean(rhs)) => Ok((lhs.into(), rhs.into())), + (FieldElement(lhs), FieldElement(rhs)) => Ok((lhs.into(), rhs.into())), + (Int(lhs), Int(rhs)) => Ok((lhs.into(), rhs.into())), + (lhs, rhs) => Err((lhs, rhs)), + } + } + + pub fn align_to_type(e: Self, ty: Type<'ast, T>) -> Result)> { + match ty.clone() { + Type::FieldElement => { + FieldElementExpression::try_from_typed(e).map(TypedExpression::from) + } + Type::Boolean => BooleanExpression::try_from_typed(e).map(TypedExpression::from), + Type::Uint(bitwidth) => { + UExpression::try_from_typed(e, bitwidth).map(TypedExpression::from) + } + Type::Array(array_ty) => { + ArrayExpression::try_from_typed(e, *array_ty.ty).map(TypedExpression::from) + } + Type::Struct(struct_ty) => { + StructExpression::try_from_typed(e, struct_ty).map(TypedExpression::from) + } + Type::Int => Err(e), + } + .map_err(|e| (e, ty)) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum IntExpression<'ast, T> { + Value(BigUint), + Pos(Box>), + Neg(Box>), + Add(Box>, Box>), + Sub(Box>, Box>), + Mult(Box>, Box>), + Div(Box>, Box>), + Rem(Box>, Box>), + Pow(Box>, Box>), + IfElse( + Box>, + Box>, + Box>, + ), + Select(Box>, Box>), + Xor(Box>, Box>), + And(Box>, Box>), + Or(Box>, Box>), + Not(Box>), + LeftShift(Box>, Box>), + RightShift(Box>, Box>), +} + +impl<'ast, T> Add for IntExpression<'ast, T> { + type Output = Self; + + fn add(self, other: Self) -> Self { + IntExpression::Add(box self, box other) + } +} + +impl<'ast, T> Sub for IntExpression<'ast, T> { + type Output = Self; + + fn sub(self, other: Self) -> Self { + IntExpression::Sub(box self, box other) + } +} + +impl<'ast, T> Mul for IntExpression<'ast, T> { + type Output = Self; + + fn mul(self, other: Self) -> Self { + IntExpression::Mult(box self, box other) + } +} + +impl<'ast, T> Div for IntExpression<'ast, T> { + type Output = Self; + + fn div(self, other: Self) -> Self { + IntExpression::Div(box self, box other) + } +} + +impl<'ast, T> Rem for IntExpression<'ast, T> { + type Output = Self; + + fn rem(self, other: Self) -> Self { + IntExpression::Rem(box self, box other) + } +} + +impl<'ast, T> Not for IntExpression<'ast, T> { + type Output = Self; + + fn not(self) -> Self { + IntExpression::Not(box self) + } +} + +impl<'ast, T> IntExpression<'ast, T> { + pub fn pow(self, other: Self) -> Self { + IntExpression::Pow(box self, box other) + } + + pub fn and(self, other: Self) -> Self { + IntExpression::And(box self, box other) + } + + pub fn xor(self, other: Self) -> Self { + IntExpression::Xor(box self, box other) + } + + pub fn or(self, other: Self) -> Self { + IntExpression::Or(box self, box other) + } + + pub fn left_shift(self, by: UExpression<'ast, T>) -> Self { + IntExpression::LeftShift(box self, box by) + } + + pub fn right_shift(self, by: UExpression<'ast, T>) -> Self { + IntExpression::RightShift(box self, box by) + } + + pub fn pos(self) -> Self { + IntExpression::Pos(box self) + } + + pub fn neg(self) -> Self { + IntExpression::Neg(box self) + } +} + +impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + IntExpression::Value(ref v) => write!(f, "{}", v), + IntExpression::Pos(ref e) => write!(f, "(+{})", e), + IntExpression::Neg(ref e) => write!(f, "(-{})", e), + IntExpression::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), + IntExpression::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), + IntExpression::Pow(ref lhs, ref rhs) => write!(f, "({} ** {})", lhs, rhs), + IntExpression::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), + IntExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), + IntExpression::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), + IntExpression::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs), + IntExpression::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), + IntExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), + IntExpression::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), + IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), + IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), + IntExpression::Not(ref e) => write!(f, "!{}", e), + IntExpression::IfElse(ref condition, ref consequent, ref alternative) => write!( + f, + "if {} then {} else {} fi", + condition, consequent, alternative + ), + } + } +} + +impl<'ast, T: Field> BooleanExpression<'ast, T> { + pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result> { + match e { + TypedExpression::Boolean(e) => Ok(e), + e => Err(e), + } + } +} + +impl<'ast, T: Field> FieldElementExpression<'ast, T> { + pub fn try_from_typed(e: TypedExpression<'ast, T>) -> Result> { + match e { + TypedExpression::FieldElement(e) => Ok(e), + TypedExpression::Int(e) => { + Self::try_from_int(e.clone()).map_err(|_| TypedExpression::Int(e)) + } + e => Err(e), + } + } + + pub fn try_from_int(i: IntExpression<'ast, T>) -> Result> { + match i { + IntExpression::Value(i) => Ok(Self::Number(T::try_from(i.clone()).map_err(|_| i)?)), + IntExpression::Add(box e1, box e2) => Ok(Self::Add( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Sub(box e1, box e2) => Ok(Self::Sub( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Mult(box e1, box e2) => Ok(Self::Mult( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Pow(box e1, box e2) => Ok(Self::Pow( + box Self::try_from_int(e1)?, + box UExpression::try_from_int(e2, UBitwidth::B32)?, + )), + IntExpression::Div(box e1, box e2) => Ok(Self::Div( + box Self::try_from_int(e1)?, + box Self::try_from_int(e2)?, + )), + IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), + IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), + IntExpression::IfElse(box condition, box consequence, box alternative) => { + Ok(Self::IfElse( + box condition, + box Self::try_from_int(consequence)?, + box Self::try_from_int(alternative)?, + )) + } + IntExpression::Select(box array, box index) => { + let size = array.size(); + + match array.into_inner() { + ArrayExpressionInner::Value(values) => { + let values = values + .into_iter() + .map(|v| { + TypedExpressionOrSpread::align_to_type(v, Type::FieldElement) + .map_err(|(e, _)| match e { + TypedExpressionOrSpread::Expression(e) => { + IntExpression::try_from(e).unwrap() + } + TypedExpressionOrSpread::Spread(a) => { + IntExpression::select(a.array, 0u32) + } + }) + }) + .collect::, _>>()?; + Ok(FieldElementExpression::select( + ArrayExpressionInner::Value(values.into()) + .annotate(Type::FieldElement, size), + index, + )) + } + _ => unreachable!(), + } + } + i => Err(i), + } + } +} + +impl<'ast, T: Field> UExpression<'ast, T> { + pub fn try_from_typed( + e: TypedExpression<'ast, T>, + bitwidth: UBitwidth, + ) -> Result> { + match e { + TypedExpression::Uint(e) => match e.bitwidth == bitwidth { + true => Ok(e), + _ => Err(TypedExpression::Uint(e)), + }, + TypedExpression::Int(e) => { + Self::try_from_int(e.clone(), bitwidth).map_err(|_| TypedExpression::Int(e)) + } + e => Err(e), + } + } + + pub fn try_from_int( + i: IntExpression<'ast, T>, + bitwidth: UBitwidth, + ) -> Result> { + use self::IntExpression::*; + + match i { + Value(i) => { + if i <= BigUint::from(2u128.pow(bitwidth.to_usize() as u32) - 1) { + Ok(UExpressionInner::Value( + u128::from_str_radix(&i.to_str_radix(16), 16).unwrap(), + ) + .annotate(bitwidth)) + } else { + Err(Value(i)) + } + } + Add(box e1, box e2) => { + Ok(Self::try_from_int(e1, bitwidth)? + Self::try_from_int(e2, bitwidth)?) + } + Pos(box e) => Ok(Self::pos(Self::try_from_int(e, bitwidth)?)), + Neg(box e) => Ok(Self::neg(Self::try_from_int(e, bitwidth)?)), + Sub(box e1, box e2) => { + Ok(Self::try_from_int(e1, bitwidth)? - Self::try_from_int(e2, bitwidth)?) + } + Mult(box e1, box e2) => { + Ok(Self::try_from_int(e1, bitwidth)? * Self::try_from_int(e2, bitwidth)?) + } + Div(box e1, box e2) => { + Ok(Self::try_from_int(e1, bitwidth)? / Self::try_from_int(e2, bitwidth)?) + } + Rem(box e1, box e2) => { + Ok(Self::try_from_int(e1, bitwidth)? % Self::try_from_int(e2, bitwidth)?) + } + And(box e1, box e2) => Ok(UExpression::and( + Self::try_from_int(e1, bitwidth)?, + Self::try_from_int(e2, bitwidth)?, + )), + Or(box e1, box e2) => Ok(UExpression::or( + Self::try_from_int(e1, bitwidth)?, + Self::try_from_int(e2, bitwidth)?, + )), + Not(box e) => Ok(!Self::try_from_int(e, bitwidth)?), + Xor(box e1, box e2) => Ok(UExpression::xor( + Self::try_from_int(e1, bitwidth)?, + Self::try_from_int(e2, bitwidth)?, + )), + RightShift(box e1, box e2) => Ok(UExpression::right_shift( + Self::try_from_int(e1, bitwidth)?, + e2, + )), + LeftShift(box e1, box e2) => Ok(UExpression::left_shift( + Self::try_from_int(e1, bitwidth)?, + e2, + )), + IfElse(box condition, box consequence, box alternative) => Ok(UExpression::if_else( + condition, + Self::try_from_int(consequence, bitwidth)?, + Self::try_from_int(alternative, bitwidth)?, + )), + Select(box array, box index) => { + let size = array.size(); + match array.into_inner() { + ArrayExpressionInner::Value(values) => { + let values = values + .into_iter() + .map(|v| { + TypedExpressionOrSpread::align_to_type(v, Type::Uint(bitwidth)) + .map_err(|(e, _)| match e { + TypedExpressionOrSpread::Expression(e) => { + IntExpression::try_from(e).unwrap() + } + TypedExpressionOrSpread::Spread(a) => { + IntExpression::select(a.array, 0u32) + } + }) + }) + .collect::, _>>()?; + Ok(UExpression::select( + ArrayExpressionInner::Value(values.into()) + .annotate(Type::Uint(bitwidth), size), + index, + )) + } + _ => unreachable!(), + } + } + i => Err(i), + } + } +} + +impl<'ast, T: Field> ArrayExpression<'ast, T> { + pub fn try_from_typed( + e: TypedExpression<'ast, T>, + target_inner_ty: Type<'ast, T>, + ) -> Result> { + match e { + TypedExpression::Array(e) => Self::try_from_int(e.clone(), target_inner_ty) + .map_err(|_| TypedExpression::Array(e)), + e => Err(e), + } + } + + // precondition: `array` is only made of inline arrays unless it does not contain the Integer type + pub fn try_from_int( + array: Self, + target_inner_ty: Type<'ast, T>, + ) -> Result> { + let array_ty = array.get_array_type(); + + // elements must fit in the target type + match array.into_inner() { + ArrayExpressionInner::Value(inline_array) => { + let res = match target_inner_ty.clone() { + Type::Int => Ok(inline_array), + t => { + // try to convert all elements to the target type + inline_array + .into_iter() + .map(|v| { + TypedExpressionOrSpread::align_to_type(v, t.clone()).map_err( + |(e, _)| match e { + TypedExpressionOrSpread::Expression(e) => e, + TypedExpressionOrSpread::Spread(a) => { + TypedExpression::select(a.array, 0u32) + } + }, + ) + }) + .collect::, _>>() + .map(|v| v.into()) + } + }?; + + let inner_ty = res.0[0].get_type().0; + + Ok(ArrayExpressionInner::Value(res).annotate(inner_ty, array_ty.size)) + } + ArrayExpressionInner::Repeat(box e, box count) => { + match target_inner_ty.clone() { + Type::Int => Ok(ArrayExpressionInner::Repeat(box e, box count) + .annotate(Type::Int, array_ty.size)), + // try to convert the repeated element to the target type + t => TypedExpression::align_to_type(e, t) + .map(|e| { + ArrayExpressionInner::Repeat(box e, box count) + .annotate(target_inner_ty, array_ty.size) + }) + .map_err(|(e, _)| e), + } + } + a => { + if array_ty.ty.weak_eq(&target_inner_ty) { + Ok(a.annotate(*array_ty.ty, array_ty.size)) + } else { + Err(a.annotate(*array_ty.ty, array_ty.size).into()) + } + } + } + } +} + +impl<'ast, T> From for IntExpression<'ast, T> { + fn from(v: BigUint) -> Self { + IntExpression::Value(v) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use zokrates_field::Bn128Field; + + #[test] + fn field_from_int() { + let n: IntExpression = BigUint::from(42usize).into(); + let n_a: ArrayExpression = + ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + let t: FieldElementExpression = Bn128Field::from(42).into(); + let t_a: ArrayExpression = + ArrayExpressionInner::Value(vec![t.clone().into()].into()) + .annotate(Type::FieldElement, 1u32); + let i: UExpression = 42u32.into(); + let c: BooleanExpression = true.into(); + + let expressions = vec![ + n.clone(), + n.clone() + n.clone(), + n.clone() - n.clone(), + n.clone() * n.clone(), + IntExpression::pow(n.clone(), n.clone()), + n.clone() / n.clone(), + IntExpression::if_else(c.clone(), n.clone(), n.clone()), + IntExpression::select(n_a.clone(), i.clone()), + ]; + + let expected = vec![ + t.clone(), + t.clone() + t.clone(), + t.clone() - t.clone(), + t.clone() * t.clone(), + FieldElementExpression::pow(t.clone(), i.clone()), + t.clone() / t.clone(), + FieldElementExpression::if_else(c.clone(), t.clone(), t.clone()), + FieldElementExpression::select(t_a.clone(), i.clone()), + ]; + + assert_eq!( + expressions + .into_iter() + .map(|e| FieldElementExpression::try_from_int(e).unwrap()) + .collect::>(), + expected + ); + + let should_error = vec![ + BigUint::parse_bytes(b"99999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap().into(), + IntExpression::xor(n.clone(), n.clone()), + IntExpression::or(n.clone(), n.clone()), + IntExpression::and(n.clone(), n.clone()), + IntExpression::left_shift(n.clone(), i.clone()), + IntExpression::right_shift(n.clone(), i.clone()), + IntExpression::not(n.clone()), + ]; + + for e in should_error + .into_iter() + .map(|e| FieldElementExpression::try_from_int(e)) + { + assert!(e.is_err()); + } + } + + #[test] + fn uint_from_int() { + let n: IntExpression = BigUint::from(42usize).into(); + let n_a: ArrayExpression = + ArrayExpressionInner::Value(vec![n.clone().into()].into()).annotate(Type::Int, 1u32); + let t: UExpression = 42u32.into(); + let t_a: ArrayExpression = + ArrayExpressionInner::Value(vec![t.clone().into()].into()) + .annotate(Type::Uint(UBitwidth::B32), 1u32); + let i: UExpression = 0u32.into(); + let c: BooleanExpression = true.into(); + + let expressions = vec![ + n.clone(), + n.clone() + n.clone(), + IntExpression::xor(n.clone(), n.clone()), + IntExpression::or(n.clone(), n.clone()), + IntExpression::and(n.clone(), n.clone()), + n.clone() - n.clone(), + n.clone() * n.clone(), + n.clone() / n.clone(), + n.clone() % n.clone(), + IntExpression::left_shift(n.clone(), i.clone()), + IntExpression::right_shift(n.clone(), i.clone()), + !n.clone(), + IntExpression::if_else(c.clone(), n.clone(), n.clone()), + IntExpression::select(n_a.clone(), i.clone()), + ]; + + let expected = vec![ + t.clone(), + t.clone() + t.clone(), + UExpression::xor(t.clone(), t.clone()), + UExpression::or(t.clone(), t.clone()), + UExpression::and(t.clone(), t.clone()), + t.clone() - t.clone(), + t.clone() * t.clone(), + t.clone() / t.clone(), + t.clone() % t.clone(), + UExpression::left_shift(t.clone(), i.clone()), + UExpression::right_shift(t.clone(), i.clone()), + !t.clone(), + UExpression::if_else(c.clone(), t.clone(), t.clone()), + UExpression::select(t_a.clone(), i.clone()), + ]; + + for (r, e) in expressions + .into_iter() + .map(|e| UExpression::try_from_int(e, UBitwidth::B32).unwrap()) + .zip(expected) + { + assert_eq!(r, e); + } + + let should_error = vec![ + BigUint::parse_bytes(b"99999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap().into(), + IntExpression::pow(n.clone(), n.clone()), + ]; + + for e in should_error + .into_iter() + .map(|e| UExpression::try_from_int(e, UBitwidth::B32)) + { + assert!(e.is_err()); + } + } +} diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index e158b7ae..32d2aa2a 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -8,28 +8,42 @@ pub mod abi; pub mod folder; pub mod identifier; +pub mod result_folder; +mod integer; mod parameter; pub mod types; mod uint; mod variable; pub use self::identifier::CoreIdentifier; -pub use self::parameter::Parameter; -pub use self::types::{Signature, StructType, Type, UBitwidth}; -pub use self::variable::Variable; -pub use crate::typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; +pub use self::parameter::{DeclarationParameter, GParameter}; +pub use self::types::{ + ConcreteFunctionKey, ConcreteSignature, ConcreteType, DeclarationFunctionKey, + DeclarationSignature, DeclarationType, GArrayType, GStructType, GType, GenericIdentifier, + Signature, StructType, Type, UBitwidth, +}; +use crate::typed_absy::types::ConcreteGenericsAssignment; + +pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::path::PathBuf; +pub use crate::typed_absy::integer::IntExpression; +pub use crate::typed_absy::uint::{bitwidth, UExpression, UExpressionInner, UMetadata}; + use crate::embed::FlatEmbed; -use crate::typed_absy::types::{FunctionKey, MemberId}; + use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fmt; + +pub use crate::typed_absy::types::{ArrayType, FunctionKey, MemberId}; + use zokrates_field::Field; pub use self::folder::Folder; use crate::typed_absy::abi::{Abi, AbiInput}; +use std::ops::{Add, Div, Mul, Sub}; pub use self::identifier::Identifier; @@ -43,7 +57,8 @@ pub type TypedModules<'ast, T> = HashMap>; /// # Remarks /// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey` /// in a given `TypedModule`, hence the use of a HashMap -pub type TypedFunctionSymbols<'ast, T> = HashMap, TypedFunctionSymbol<'ast, T>>; +pub type TypedFunctionSymbols<'ast, T> = + HashMap, TypedFunctionSymbol<'ast, T>>; /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug, Clone)] @@ -52,6 +67,12 @@ pub struct TypedProgram<'ast, T> { pub main: TypedModuleId, } +impl<'ast, T> TypedProgram<'ast, T> { + pub fn main_function(&self) -> TypedFunction<'ast, T> { + unimplemented!() + } +} + impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn abi(&self) -> Abi { let main = self.modules[&self.main] @@ -69,13 +90,24 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { inputs: main .arguments .iter() - .map(|p| AbiInput { - public: !p.private, - name: p.id.id.to_string(), - ty: p.id._type.clone(), + .map(|p| { + types::ConcreteType::try_from(types::Type::::from(p.id._type.clone())) + .map(|ty| AbiInput { + public: !p.private, + name: p.id.id.to_string(), + ty, + }) + .unwrap() + }) + .collect(), + outputs: main + .signature + .outputs + .iter() + .map(|ty| { + types::ConcreteType::try_from(types::Type::::from(ty.clone())).unwrap() }) .collect(), - outputs: main.signature.outputs.clone(), } } } @@ -96,7 +128,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedProgram<'ast, T> { writeln!(f, "{}", "-".repeat(100))?; writeln!(f, "{}", module)?; writeln!(f, "{}", "-".repeat(100))?; - writeln!(f, "")?; + writeln!(f)?; } write!(f, "") } @@ -112,7 +144,7 @@ pub struct TypedModule<'ast, T> { #[derive(Clone, PartialEq)] pub enum TypedFunctionSymbol<'ast, T> { Here(TypedFunction<'ast, T>), - There(FunctionKey<'ast>, TypedModuleId), + There(DeclarationFunctionKey<'ast>), Flat(FlatEmbed), } @@ -121,24 +153,26 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunctionSymbol<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TypedFunctionSymbol::Here(s) => write!(f, "Here({:?})", s), - TypedFunctionSymbol::There(key, module) => write!(f, "There({:?}, {:?})", key, module), + TypedFunctionSymbol::There(key) => write!(f, "There({:?})", key), TypedFunctionSymbol::Flat(s) => write!(f, "Flat({:?})", s), } } } impl<'ast, T: Field> TypedFunctionSymbol<'ast, T> { - pub fn signature<'a>(&'a self, modules: &'a TypedModules) -> Signature { + pub fn signature<'a>( + &'a self, + modules: &'a TypedModules<'ast, T>, + ) -> DeclarationSignature<'ast> { match self { TypedFunctionSymbol::Here(f) => f.signature.clone(), - TypedFunctionSymbol::There(key, module_id) => modules - .get(module_id) + TypedFunctionSymbol::There(key) => modules + .get(&key.module) .unwrap() .functions .get(key) .unwrap() - .signature(&modules) - .clone(), + .signature(&modules), TypedFunctionSymbol::Flat(flat_fun) => flat_fun.signature(), } } @@ -151,10 +185,10 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedModule<'ast, T> { .iter() .map(|(key, symbol)| match symbol { TypedFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), - TypedFunctionSymbol::There(ref fun_key, ref module_id) => format!( + TypedFunctionSymbol::There(ref fun_key) => format!( "import {} from \"{}\" as {} // with signature {}", fun_key.id, - module_id.display(), + fun_key.module.display(), key.id, key.signature ), @@ -182,18 +216,30 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedModule<'ast, T> { } /// A typed function -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Hash)] pub struct TypedFunction<'ast, T> { /// Arguments of the function - pub arguments: Vec>, + pub arguments: Vec>, /// Vector of statements that are executed when running the function pub statements: Vec>, /// function signature - pub signature: Signature, + pub signature: DeclarationSignature<'ast>, } impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if !self.signature.generics.is_empty() { + write!( + f, + "<{}>", + self.signature + .generics + .iter() + .map(|g| g.as_ref().unwrap().to_string()) + .collect::>() + .join(", ") + )?; + } write!( f, "({})", @@ -211,7 +257,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { 0 => "".into(), 1 => format!(" -> {}", self.signature.outputs[0]), _ => format!( - "{}", + " -> ({})", self.signature .outputs .iter() @@ -222,11 +268,21 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { } )?; - writeln!(f, "")?; + writeln!(f)?; + + let mut tab = 0; for s in &self.statements { - s.fmt_indented(f, 1)?; - writeln!(f, "")?; + if let TypedStatement::PopCallLog = s { + tab -= 1; + }; + + s.fmt_indented(f, 1 + tab)?; + writeln!(f)?; + + if let TypedStatement::PushCallLog(..) = s { + tab += 1; + }; } Ok(()) @@ -237,7 +293,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "TypedFunction(arguments: {:?}, ...):\n{}", + "TypedFunction(signature: {:?}, arguments: {:?}, ...):\n{}", + self.signature, self.arguments, self.statements .iter() @@ -251,22 +308,89 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedFunction<'ast, T> { /// Something we can assign to. #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedAssignee<'ast, T> { - Identifier(Variable<'ast>), - Select( - Box>, - Box>, - ), + Identifier(Variable<'ast, T>), + Select(Box>, Box>), Member(Box>, MemberId), } -impl<'ast, T> From> for TypedAssignee<'ast, T> { - fn from(v: Variable<'ast>) -> Self { +#[derive(Clone, PartialEq, Hash, Eq, Debug)] +pub struct TypedSpread<'ast, T> { + pub array: ArrayExpression<'ast, T>, +} + +impl<'ast, T> From> for TypedSpread<'ast, T> { + fn from(array: ArrayExpression<'ast, T>) -> TypedSpread<'ast, T> { + TypedSpread { array } + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedSpread<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "...{}", self.array) + } +} + +#[derive(Clone, PartialEq, Hash, Eq, Debug)] +pub enum TypedExpressionOrSpread<'ast, T> { + Expression(TypedExpression<'ast, T>), + Spread(TypedSpread<'ast, T>), +} + +impl<'ast, T: Clone> TypedExpressionOrSpread<'ast, T> { + pub fn size(&self) -> UExpression<'ast, T> { + match self { + TypedExpressionOrSpread::Expression(..) => 1u32.into(), + TypedExpressionOrSpread::Spread(s) => s.array.size(), + } + } +} + +impl<'ast, T> TryFrom> for TypedExpression<'ast, T> { + type Error = (); + + fn try_from( + e: TypedExpressionOrSpread<'ast, T>, + ) -> Result, Self::Error> { + if let TypedExpressionOrSpread::Expression(e) = e { + Ok(e) + } else { + Err(()) + } + } +} + +impl<'ast, T, U: Into>> From for TypedExpressionOrSpread<'ast, T> { + fn from(e: U) -> Self { + TypedExpressionOrSpread::Expression(e.into()) + } +} + +impl<'ast, T: Clone> TypedExpressionOrSpread<'ast, T> { + pub fn get_type(&self) -> (Type<'ast, T>, UExpression<'ast, T>) { + match self { + TypedExpressionOrSpread::Expression(e) => (e.get_type(), 1u32.into()), + TypedExpressionOrSpread::Spread(s) => (s.array.inner_type().clone(), s.array.size()), + } + } +} + +impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionOrSpread<'ast, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TypedExpressionOrSpread::Expression(e) => write!(f, "{}", e), + TypedExpressionOrSpread::Spread(s) => write!(f, "{}", s), + } + } +} + +impl<'ast, T> From> for TypedAssignee<'ast, T> { + fn from(v: Variable<'ast, T>) -> Self { TypedAssignee::Identifier(v) } } -impl<'ast, T> Typed for TypedAssignee<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for TypedAssignee<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { match *self { TypedAssignee::Identifier(ref v) => v.get_type(), TypedAssignee::Select(ref a, _) => { @@ -295,7 +419,7 @@ impl<'ast, T> Typed for TypedAssignee<'ast, T> { impl<'ast, T: fmt::Debug> fmt::Debug for TypedAssignee<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedAssignee::Identifier(ref s) => write!(f, "{}", s.id), + TypedAssignee::Identifier(ref s) => write!(f, "{:?}", s.id), TypedAssignee::Select(ref a, ref e) => write!(f, "Select({:?}, {:?})", a, e), TypedAssignee::Member(ref s, ref m) => write!(f, "Member({:?}, {:?})", s, m), } @@ -313,19 +437,26 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedAssignee<'ast, T> { } /// A statement in a `TypedFunction` +#[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedStatement<'ast, T> { Return(Vec>), Definition(TypedAssignee<'ast, T>, TypedExpression<'ast, T>), - Declaration(Variable<'ast>), + Declaration(Variable<'ast, T>), Assertion(BooleanExpression<'ast, T>), For( - Variable<'ast>, - FieldElementExpression<'ast, T>, - FieldElementExpression<'ast, T>, + Variable<'ast, T>, + UExpression<'ast, T>, + UExpression<'ast, T>, Vec>, ), MultipleDefinition(Vec>, TypedExpressionList<'ast, T>), + // Aux + PushCallLog( + DeclarationFunctionKey<'ast>, + ConcreteGenericsAssignment<'ast>, + ), + PopCallLog, } impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> { @@ -341,21 +472,25 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedStatement<'ast, T> { } write!(f, ")") } - TypedStatement::Declaration(ref var) => write!(f, "Declaration({:?})", var), + TypedStatement::Declaration(ref var) => write!(f, "({:?})", var), TypedStatement::Definition(ref lhs, ref rhs) => { write!(f, "Definition({:?}, {:?})", lhs, rhs) } TypedStatement::Assertion(ref e) => write!(f, "Assertion({:?})", e), TypedStatement::For(ref var, ref start, ref stop, ref list) => { - write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?; + writeln!(f, "for {:?} in {:?}..{:?} do", var, start, stop)?; for l in list { - write!(f, "\t\t{:?}\n", l)?; + writeln!(f, "\t\t{:?}", l)?; } write!(f, "\tendfor") } TypedStatement::MultipleDefinition(ref lhs, ref rhs) => { write!(f, "MultipleDefinition({:?}, {:?})", lhs, rhs) } + TypedStatement::PushCallLog(ref key, ref generics) => { + write!(f, "PushCallLog({:?}, {:?})", key, generics) + } + TypedStatement::PopCallLog => write!(f, "PopCallLog"), } } } @@ -368,9 +503,9 @@ impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { writeln!(f, "for {} in {}..{} do", variable, from, to)?; for s in statements { s.fmt_indented(f, depth + 1)?; - writeln!(f, "")?; + writeln!(f)?; } - writeln!(f, "{}endfor", "\t".repeat(depth)) + write!(f, "{}endfor", "\t".repeat(depth)) } s => write!(f, "{}{}", "\t".repeat(depth), s), } @@ -394,9 +529,9 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), TypedStatement::Assertion(ref e) => write!(f, "assert({})", e), TypedStatement::For(ref var, ref start, ref stop, ref list) => { - write!(f, "for {} in {}..{} do\n", var, start, stop)?; + writeln!(f, "for {} in {}..{} do", var, start, stop)?; for l in list { - write!(f, "\t\t{}\n", l)?; + writeln!(f, "\t\t{}", l)?; } write!(f, "\tendfor") } @@ -409,15 +544,24 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { } write!(f, " = {}", rhs) } + TypedStatement::PushCallLog(ref key, ref generics) => write!( + f, + "// PUSH CALL TO {}/{}::<{}>", + key.module.display(), + key.id, + generics, + ), + TypedStatement::PopCallLog => write!(f, "// POP CALL",), } } } -pub trait Typed { - fn get_type(&self) -> Type; +pub trait Typed<'ast, T> { + fn get_type(&self) -> Type<'ast, T>; } /// A typed expression +#[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedExpression<'ast, T> { Boolean(BooleanExpression<'ast, T>), @@ -425,6 +569,7 @@ pub enum TypedExpression<'ast, T> { Uint(UExpression<'ast, T>), Array(ArrayExpression<'ast, T>), Struct(StructExpression<'ast, T>), + Int(IntExpression<'ast, T>), } impl<'ast, T> From> for TypedExpression<'ast, T> { @@ -439,6 +584,12 @@ impl<'ast, T> From> for TypedExpression<'ast, T> } } +impl<'ast, T> From> for TypedExpression<'ast, T> { + fn from(e: IntExpression<'ast, T>) -> TypedExpression { + TypedExpression::Int(e) + } +} + impl<'ast, T> From> for TypedExpression<'ast, T> { fn from(e: UExpression<'ast, T>) -> TypedExpression { TypedExpression::Uint(e) @@ -465,6 +616,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpression<'ast, T> { TypedExpression::Uint(ref e) => write!(f, "{}", e), TypedExpression::Array(ref e) => write!(f, "{}", e), TypedExpression::Struct(ref s) => write!(f, "{}", s), + TypedExpression::Int(ref s) => write!(f, "{}", s), } } } @@ -477,6 +629,7 @@ impl<'ast, T: fmt::Debug> fmt::Debug for TypedExpression<'ast, T> { TypedExpression::Uint(ref e) => write!(f, "{:?}", e), TypedExpression::Array(ref e) => write!(f, "{:?}", e), TypedExpression::Struct(ref s) => write!(f, "{:?}", s), + TypedExpression::Int(ref s) => write!(f, "{:?}", s), } } } @@ -499,7 +652,8 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { StructExpressionInner::Identifier(ref var) => write!(f, "{}", var), StructExpressionInner::Value(ref values) => write!( f, - "{{{}}}", + "{} {{{}}}", + self.ty.name(), self.ty .iter() .map(|member| member.id.clone()) @@ -508,8 +662,23 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { .collect::>() .join(", ") ), - StructExpressionInner::FunctionCall(ref key, ref p) => { - write!(f, "{}(", key.id,)?; + StructExpressionInner::FunctionCall(ref key, ref generics, ref p) => { + write!(f, "{}", key.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -537,61 +706,74 @@ impl<'ast, T: fmt::Debug> fmt::Debug for StructExpression<'ast, T> { } } -impl<'ast, T> Typed for TypedExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for TypedExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { match *self { TypedExpression::Boolean(ref e) => e.get_type(), TypedExpression::FieldElement(ref e) => e.get_type(), TypedExpression::Array(ref e) => e.get_type(), TypedExpression::Uint(ref e) => e.get_type(), TypedExpression::Struct(ref s) => s.get_type(), + TypedExpression::Int(_) => Type::Int, } } } -impl<'ast, T> Typed for ArrayExpression<'ast, T> { - fn get_type(&self) -> Type { - Type::array(self.ty.clone(), self.size) +impl<'ast, T: Clone> Typed<'ast, T> for ArrayExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { + Type::array(*self.ty.clone()) } } -impl<'ast, T> Typed for StructExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for StructExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Struct(self.ty.clone()) } } -impl<'ast, T> Typed for FieldElementExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for FieldElementExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::FieldElement } } -impl<'ast, T> Typed for UExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for UExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Uint(self.bitwidth) } } -impl<'ast, T> Typed for BooleanExpression<'ast, T> { - fn get_type(&self) -> Type { +impl<'ast, T: Clone> Typed<'ast, T> for BooleanExpression<'ast, T> { + fn get_type(&self) -> Type<'ast, T> { Type::Boolean } } -pub trait MultiTyped { - fn get_types(&self) -> &Vec; +pub trait MultiTyped<'ast, T> { + fn get_types(&self) -> &Vec>; } #[derive(Clone, PartialEq, Hash, Eq)] pub enum TypedExpressionList<'ast, T> { - FunctionCall(FunctionKey<'ast>, Vec>, Vec), + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, + Vec>, + ), + EmbedCall( + FlatEmbed, + Vec, + Vec>, + Vec>, + ), } -impl<'ast, T> MultiTyped for TypedExpressionList<'ast, T> { - fn get_types(&self) -> &Vec { +impl<'ast, T> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> { + fn get_types(&self) -> &Vec> { match *self { - TypedExpressionList::FunctionCall(_, _, ref types) => types, + TypedExpressionList::FunctionCall(_, _, _, ref types) => types, + TypedExpressionList::EmbedCall(_, _, _, ref types) => types, } } } @@ -619,7 +801,7 @@ pub enum FieldElementExpression<'ast, T> { ), Pow( Box>, - Box>, + Box>, ), IfElse( Box>, @@ -628,12 +810,50 @@ pub enum FieldElementExpression<'ast, T> { ), Neg(Box>), Pos(Box>), - FunctionCall(FunctionKey<'ast>, Vec>), - Member(Box>, MemberId), - Select( - Box>, - Box>, + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, ), + Member(Box>, MemberId), + Select(Box>, Box>), +} +impl<'ast, T> Add for FieldElementExpression<'ast, T> { + type Output = Self; + + fn add(self, other: Self) -> Self { + FieldElementExpression::Add(box self, box other) + } +} + +impl<'ast, T> Sub for FieldElementExpression<'ast, T> { + type Output = Self; + + fn sub(self, other: Self) -> Self { + FieldElementExpression::Sub(box self, box other) + } +} + +impl<'ast, T> Mul for FieldElementExpression<'ast, T> { + type Output = Self; + + fn mul(self, other: Self) -> Self { + FieldElementExpression::Mult(box self, box other) + } +} + +impl<'ast, T> Div for FieldElementExpression<'ast, T> { + type Output = Self; + + fn div(self, other: Self) -> Self { + FieldElementExpression::Div(box self, box other) + } +} + +impl<'ast, T> FieldElementExpression<'ast, T> { + pub fn pow(self, other: UExpression<'ast, T>) -> Self { + FieldElementExpression::Pow(box self, box other) + } } impl<'ast, T> From for FieldElementExpression<'ast, T> { @@ -647,14 +867,26 @@ impl<'ast, T> From for FieldElementExpression<'ast, T> { pub enum BooleanExpression<'ast, T> { Identifier(Identifier<'ast>), Value(bool), - Lt( + FieldLt( Box>, Box>, ), - Le( + FieldLe( Box>, Box>, ), + FieldGe( + Box>, + Box>, + ), + FieldGt( + Box>, + Box>, + ), + UintLt(Box>, Box>), + UintLe(Box>, Box>), + UintGe(Box>, Box>), + UintGt(Box>, Box>), FieldEq( Box>, Box>, @@ -669,14 +901,6 @@ pub enum BooleanExpression<'ast, T> { Box>, ), UintEq(Box>, Box>), - Ge( - Box>, - Box>, - ), - Gt( - Box>, - Box>, - ), Or( Box>, Box>, @@ -692,11 +916,18 @@ pub enum BooleanExpression<'ast, T> { Box>, ), Member(Box>, MemberId), - FunctionCall(FunctionKey<'ast>, Vec>), - Select( - Box>, - Box>, + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, ), + Select(Box>, Box>), +} + +impl<'ast, T> From for BooleanExpression<'ast, T> { + fn from(b: bool) -> Self { + BooleanExpression::Value(b) + } } /// An expression of type `array` @@ -706,45 +937,132 @@ pub enum BooleanExpression<'ast, T> { /// type checking #[derive(Clone, PartialEq, Hash, Eq)] pub struct ArrayExpression<'ast, T> { - size: usize, - ty: Type, + ty: Box>, inner: ArrayExpressionInner<'ast, T>, } +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct ArrayValue<'ast, T>(pub Vec>); + +impl<'ast, T> From>> for ArrayValue<'ast, T> { + fn from(array: Vec>) -> Self { + Self(array) + } +} + +impl<'ast, T> IntoIterator for ArrayValue<'ast, T> { + type Item = TypedExpressionOrSpread<'ast, T>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'ast, T: Clone> ArrayValue<'ast, T> { + fn expression_at_aux + Into>>( + v: TypedExpressionOrSpread<'ast, T>, + ) -> Vec>> { + match v { + TypedExpressionOrSpread::Expression(e) => vec![Some(e.clone())], + TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() { + UExpressionInner::Value(size) => { + let array_ty = s.array.get_array_type().clone(); + + match s.array.into_inner() { + ArrayExpressionInner::Value(v) => v + .into_iter() + .flat_map(Self::expression_at_aux::) + .collect(), + a => (0..size) + .map(|i| { + Some( + U::select( + a.clone() + .annotate(*array_ty.ty.clone(), array_ty.size.clone()), + i as u32, + ) + .into(), + ) + }) + .collect(), + } + } + _ => vec![None], + }, + } + } + + pub fn expression_at + Into>>( + &self, + index: usize, + ) -> Option> { + self.0 + .iter() + .map(|v| Self::expression_at_aux::(v.clone())) + .flatten() + .take_while(|e| e.is_some()) + .map(|e| e.unwrap()) + .nth(index) + } +} + +impl<'ast, T> ArrayValue<'ast, T> { + fn iter(&self) -> std::slice::Iter> { + self.0.iter() + } +} + +impl<'ast, T> std::iter::FromIterator> for ArrayValue<'ast, T> { + fn from_iter>>(iter: I) -> Self { + Self(iter.into_iter().collect()) + } +} + #[derive(Clone, PartialEq, Hash, Eq)] pub enum ArrayExpressionInner<'ast, T> { Identifier(Identifier<'ast>), - Value(Vec>), - FunctionCall(FunctionKey<'ast>, Vec>), + Value(ArrayValue<'ast, T>), + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, + ), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( + Select(Box>, Box>), + Slice( Box>, - Box>, + Box>, + Box>, ), + Repeat(Box>, Box>), } impl<'ast, T> ArrayExpressionInner<'ast, T> { - pub fn annotate(self, ty: Type, size: usize) -> ArrayExpression<'ast, T> { + pub fn annotate>>( + self, + ty: Type<'ast, T>, + size: S, + ) -> ArrayExpression<'ast, T> { ArrayExpression { - size, - ty, + ty: box (ty, size.into()).into(), inner: self, } } } -impl<'ast, T> ArrayExpression<'ast, T> { - pub fn inner_type(&self) -> &Type { - &self.ty +impl<'ast, T: Clone> ArrayExpression<'ast, T> { + pub fn inner_type(&self) -> &Type<'ast, T> { + &self.ty.ty } - pub fn size(&self) -> usize { - self.size + pub fn size(&self) -> UExpression<'ast, T> { + self.ty.size.clone() } pub fn as_inner(&self) -> &ArrayExpressionInner<'ast, T> { @@ -758,16 +1076,41 @@ impl<'ast, T> ArrayExpression<'ast, T> { pub fn into_inner(self) -> ArrayExpressionInner<'ast, T> { self.inner } + + pub fn get_array_type(&self) -> ArrayType<'ast, T> { + ArrayType { + size: self.size(), + ty: box self.inner_type().clone(), + } + } } #[derive(Clone, PartialEq, Hash, Eq)] pub struct StructExpression<'ast, T> { - ty: StructType, + ty: StructType<'ast, T>, inner: StructExpressionInner<'ast, T>, } +impl<'ast, T: Field> StructExpression<'ast, T> { + pub fn try_from_typed( + e: TypedExpression<'ast, T>, + target_struct_ty: StructType<'ast, T>, + ) -> Result> { + match e { + TypedExpression::Struct(e) => { + if e.ty() == &target_struct_ty { + Ok(e) + } else { + Err(TypedExpression::Struct(e)) + } + } + e => Err(e), + } + } +} + impl<'ast, T> StructExpression<'ast, T> { - pub fn ty(&self) -> &StructType { + pub fn ty(&self) -> &StructType<'ast, T> { &self.ty } @@ -788,21 +1131,22 @@ impl<'ast, T> StructExpression<'ast, T> { pub enum StructExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), - FunctionCall(FunctionKey<'ast>, Vec>), + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, + ), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } impl<'ast, T> StructExpressionInner<'ast, T> { - pub fn annotate(self, ty: StructType) -> StructExpression<'ast, T> { + pub fn annotate(self, ty: StructType<'ast, T>) -> StructExpression<'ast, T> { StructExpression { ty, inner: self } } } @@ -857,6 +1201,17 @@ impl<'ast, T> TryFrom> for ArrayExpression<'ast, T> { } } +impl<'ast, T> TryFrom> for IntExpression<'ast, T> { + type Error = (); + + fn try_from(te: TypedExpression<'ast, T>) -> Result, Self::Error> { + match te { + TypedExpression::Int(e) => Ok(e), + _ => Err(()), + } + } +} + impl<'ast, T> TryFrom> for StructExpression<'ast, T> { type Error = (); @@ -871,7 +1226,7 @@ impl<'ast, T> TryFrom> for StructExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldElementExpression::Number(ref i) => write!(f, "{}", i), + FieldElementExpression::Number(ref i) => write!(f, "{}f", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), FieldElementExpression::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), @@ -887,8 +1242,23 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { condition, consequent, alternative ) } - FieldElementExpression::FunctionCall(ref k, ref p) => { - write!(f, "{}(", k.id,)?; + FieldElementExpression::FunctionCall(ref k, ref generics, ref p) => { + write!(f, "{}", k.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -906,7 +1276,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.inner { - UExpressionInner::Value(ref v) => write!(f, "0x{:x}", v), + UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), UExpressionInner::And(ref lhs, ref rhs) => write!(f, "({} & {})", lhs, rhs), @@ -914,6 +1284,9 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Xor(ref lhs, ref rhs) => write!(f, "({} ^ {})", lhs, rhs), UExpressionInner::Sub(ref lhs, ref rhs) => write!(f, "({} - {})", lhs, rhs), UExpressionInner::Mult(ref lhs, ref rhs) => write!(f, "({} * {})", lhs, rhs), + UExpressionInner::FloorSub(ref lhs, ref rhs) => { + write!(f, "(FLOOR_SUB({}, {}))", lhs, rhs) + } UExpressionInner::Div(ref lhs, ref rhs) => write!(f, "({} / {})", lhs, rhs), UExpressionInner::Rem(ref lhs, ref rhs) => write!(f, "({} % {})", lhs, rhs), UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), @@ -922,8 +1295,23 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Neg(ref e) => write!(f, "(-{})", e), UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), UExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), - UExpressionInner::FunctionCall(ref k, ref p) => { - write!(f, "{}(", k.id,)?; + UExpressionInner::FunctionCall(ref k, ref generics, ref p) => { + write!(f, "{}", k.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -946,21 +1334,40 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), - BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), + BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::StructEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), - BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), - BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), BooleanExpression::Value(b) => write!(f, "{}", b), - BooleanExpression::FunctionCall(ref k, ref p) => { - write!(f, "{}(", k.id,)?; + BooleanExpression::FunctionCall(ref k, ref generics, ref p) => { + write!(f, "{}", k.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -993,8 +1400,23 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { .collect::>() .join(", ") ), - ArrayExpressionInner::FunctionCall(ref key, ref p) => { - write!(f, "{}(", key.id,)?; + ArrayExpressionInner::FunctionCall(ref key, ref generics, ref p) => { + write!(f, "{}", key.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -1010,6 +1432,12 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { ), ArrayExpressionInner::Member(ref s, ref id) => write!(f, "{}.{}", s, id), ArrayExpressionInner::Select(ref id, ref index) => write!(f, "{}[{}]", id, index), + ArrayExpressionInner::Slice(ref a, ref from, ref to) => { + write!(f, "{}[{}..{}]", a, from, to) + } + ArrayExpressionInner::Repeat(ref e, ref count) => { + write!(f, "[{}; {}]", e, count) + } } } } @@ -1024,8 +1452,30 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> { "IfElse({:?}, {:?}, {:?})", condition, consequent, alternative ), - BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "Lt({:?}, {:?})", lhs, rhs), - BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "Le({:?}, {:?})", lhs, rhs), + BooleanExpression::FieldLt(ref lhs, ref rhs) => { + write!(f, "FieldLt({:?}, {:?})", lhs, rhs) + } + BooleanExpression::FieldLe(ref lhs, ref rhs) => { + write!(f, "FieldLe({:?}, {:?})", lhs, rhs) + } + BooleanExpression::FieldGe(ref lhs, ref rhs) => { + write!(f, "FieldGe({:?}, {:?})", lhs, rhs) + } + BooleanExpression::FieldGt(ref lhs, ref rhs) => { + write!(f, "FieldGt({:?}, {:?})", lhs, rhs) + } + BooleanExpression::UintLt(ref lhs, ref rhs) => { + write!(f, "UintLt({:?}, {:?})", lhs, rhs) + } + BooleanExpression::UintLe(ref lhs, ref rhs) => { + write!(f, "UintLe({:?}, {:?})", lhs, rhs) + } + BooleanExpression::UintGe(ref lhs, ref rhs) => { + write!(f, "UintGe({:?}, {:?})", lhs, rhs) + } + BooleanExpression::UintGt(ref lhs, ref rhs) => { + write!(f, "UintGt({:?}, {:?})", lhs, rhs) + } BooleanExpression::FieldEq(ref lhs, ref rhs) => { write!(f, "FieldEq({:?}, {:?})", lhs, rhs) } @@ -1041,12 +1491,10 @@ impl<'ast, T: fmt::Debug> fmt::Debug for BooleanExpression<'ast, T> { BooleanExpression::UintEq(ref lhs, ref rhs) => { write!(f, "UintEq({:?}, {:?})", lhs, rhs) } - BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "Ge({:?}, {:?})", lhs, rhs), - BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "Gt({:?}, {:?})", lhs, rhs), BooleanExpression::And(ref lhs, ref rhs) => write!(f, "And({:?}, {:?})", lhs, rhs), BooleanExpression::Not(ref exp) => write!(f, "Not({:?})", exp), - BooleanExpression::FunctionCall(ref i, ref p) => { - write!(f, "FunctionCall({:?}, (", i)?; + BooleanExpression::FunctionCall(ref i, ref g, ref p) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } @@ -1082,8 +1530,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for FieldElementExpression<'ast, T> { condition, consequent, alternative ) } - FieldElementExpression::FunctionCall(ref i, ref p) => { - write!(f, "FunctionCall({:?}, (", i)?; + FieldElementExpression::FunctionCall(ref i, ref g, ref p) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } @@ -1102,8 +1550,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ArrayExpressionInner<'ast, T> { match *self { ArrayExpressionInner::Identifier(ref var) => write!(f, "Identifier({:?})", var), ArrayExpressionInner::Value(ref values) => write!(f, "Value({:?})", values), - ArrayExpressionInner::FunctionCall(ref i, ref p) => { - write!(f, "FunctionCall({:?}, (", i)?; + ArrayExpressionInner::FunctionCall(ref i, ref g, ref p) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } @@ -1115,8 +1563,14 @@ impl<'ast, T: fmt::Debug> fmt::Debug for ArrayExpressionInner<'ast, T> { ArrayExpressionInner::Member(ref struc, ref id) => { write!(f, "Member({:?}, {:?})", struc, id) } - ArrayExpressionInner::Select(ref id, ref index) => { - write!(f, "Select({:?}, {:?})", id, index) + ArrayExpressionInner::Select(ref array, ref index) => { + write!(f, "Select({:?}, {:?})", array, index) + } + ArrayExpressionInner::Slice(ref array, ref from, ref to) => { + write!(f, "Slice({:?}, {:?}, {:?})", array, from, to) + } + ArrayExpressionInner::Repeat(ref e, ref count) => { + write!(f, "Repeat({:?}, {:?})", e, count) } } } @@ -1127,8 +1581,8 @@ impl<'ast, T: fmt::Debug> fmt::Debug for StructExpressionInner<'ast, T> { match *self { StructExpressionInner::Identifier(ref var) => write!(f, "{:?}", var), StructExpressionInner::Value(ref values) => write!(f, "{:?}", values), - StructExpressionInner::FunctionCall(ref i, ref p) => { - write!(f, "FunctionCall({:?}, (", i)?; + StructExpressionInner::FunctionCall(ref i, ref g, ref p) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } @@ -1152,8 +1606,45 @@ impl<'ast, T: fmt::Debug> fmt::Debug for StructExpressionInner<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedExpressionList::FunctionCall(ref key, ref p, _) => { - write!(f, "{}(", key.id,)?; + TypedExpressionList::FunctionCall(ref k, ref generics, ref p, _) => { + write!(f, "{}", k.id,)?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; + for (i, param) in p.iter().enumerate() { + write!(f, "{}", param)?; + if i < p.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, ")") + } + TypedExpressionList::EmbedCall(ref embed, ref generics, ref p, _) => { + write!(f, "{}", embed.id())?; + if !generics.is_empty() { + write!( + f, + "::<{}>", + generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -1169,11 +1660,33 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> { impl<'ast, T: fmt::Debug> fmt::Debug for TypedExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - TypedExpressionList::FunctionCall(ref i, ref p, _) => { - write!(f, "FunctionCall({:?}, (", i)?; + TypedExpressionList::FunctionCall(ref i, ref g, ref p, _) => { + write!(f, "FunctionCall({:?}, {:?}, (", g, i)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } + TypedExpressionList::EmbedCall(ref embed, ref g, ref p, _) => { + write!(f, "EmbedCall({:?}, {:?}, (", g, embed)?; + f.debug_list().entries(p.iter()).finish()?; + write!(f, ")") + } + } + } +} + +// Variable to TypedExpression conversion + +impl<'ast, T: Field> From> for TypedExpression<'ast, T> { + fn from(v: Variable<'ast, T>) -> Self { + match v.get_type() { + Type::FieldElement => FieldElementExpression::Identifier(v.id).into(), + Type::Boolean => BooleanExpression::Identifier(v.id).into(), + Type::Array(ty) => ArrayExpressionInner::Identifier(v.id) + .annotate(*ty.ty, ty.size) + .into(), + Type::Struct(ty) => StructExpressionInner::Identifier(v.id).annotate(ty).into(), + Type::Uint(w) => UExpressionInner::Identifier(v.id).annotate(w).into(), + Type::Int => unreachable!(), } } } @@ -1195,6 +1708,16 @@ impl<'ast, T> IfElse<'ast, T> for FieldElementExpression<'ast, T> { } } +impl<'ast, T> IfElse<'ast, T> for IntExpression<'ast, T> { + fn if_else( + condition: BooleanExpression<'ast, T>, + consequence: Self, + alternative: Self, + ) -> Self { + IntExpression::IfElse(box condition, box consequence, box alternative) + } +} + impl<'ast, T> IfElse<'ast, T> for BooleanExpression<'ast, T> { fn if_else( condition: BooleanExpression<'ast, T>, @@ -1217,7 +1740,7 @@ impl<'ast, T> IfElse<'ast, T> for UExpression<'ast, T> { } } -impl<'ast, T> IfElse<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone> IfElse<'ast, T> for ArrayExpression<'ast, T> { fn if_else( condition: BooleanExpression<'ast, T>, consequence: Self, @@ -1230,7 +1753,7 @@ impl<'ast, T> IfElse<'ast, T> for ArrayExpression<'ast, T> { } } -impl<'ast, T> IfElse<'ast, T> for StructExpression<'ast, T> { +impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> { fn if_else( condition: BooleanExpression<'ast, T>, consequence: Self, @@ -1242,51 +1765,70 @@ impl<'ast, T> IfElse<'ast, T> for StructExpression<'ast, T> { } pub trait Select<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self; + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self; } impl<'ast, T> Select<'ast, T> for FieldElementExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { - FieldElementExpression::Select(box array, box index) + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { + FieldElementExpression::Select(box array, box index.into()) + } +} + +impl<'ast, T> Select<'ast, T> for IntExpression<'ast, T> { + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { + IntExpression::Select(box array, box index.into()) } } impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { - BooleanExpression::Select(box array, box index) + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { + BooleanExpression::Select(box array, box index.into()) } } -impl<'ast, T> Select<'ast, T> for UExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { +impl<'ast, T: Clone> Select<'ast, T> for TypedExpression<'ast, T> { + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { + match *array.get_array_type().ty { + Type::Array(..) => ArrayExpression::select(array, index).into(), + Type::Struct(..) => StructExpression::select(array, index).into(), + Type::FieldElement => FieldElementExpression::select(array, index).into(), + Type::Boolean => BooleanExpression::select(array, index).into(), + Type::Int => IntExpression::select(array, index).into(), + Type::Uint(..) => UExpression::select(array, index).into(), + } + } +} + +impl<'ast, T: Clone> Select<'ast, T> for UExpression<'ast, T> { + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { let bitwidth = match array.inner_type().clone() { Type::Uint(bitwidth) => bitwidth, _ => unreachable!(), }; - UExpressionInner::Select(box array, box index).annotate(bitwidth) + UExpressionInner::Select(box array, box index.into()).annotate(bitwidth) } } -impl<'ast, T> Select<'ast, T> for ArrayExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { +impl<'ast, T: Clone> Select<'ast, T> for ArrayExpression<'ast, T> { + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { let (ty, size) = match array.inner_type() { Type::Array(array_type) => (array_type.ty.clone(), array_type.size.clone()), _ => unreachable!(), }; - ArrayExpressionInner::Select(box array, box index).annotate(*ty, size) + ArrayExpressionInner::Select(box array, box index.into()).annotate(*ty, size) } } -impl<'ast, T> Select<'ast, T> for StructExpression<'ast, T> { - fn select(array: ArrayExpression<'ast, T>, index: FieldElementExpression<'ast, T>) -> Self { +impl<'ast, T: Clone> Select<'ast, T> for StructExpression<'ast, T> { + fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { let members = match array.inner_type().clone() { Type::Struct(members) => members, _ => unreachable!(), }; - StructExpressionInner::Select(box array, box index).annotate(members) + StructExpressionInner::Select(box array, box index.into()).annotate(members) } } @@ -1306,15 +1848,16 @@ impl<'ast, T> Member<'ast, T> for BooleanExpression<'ast, T> { } } -impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> { +impl<'ast, T: Clone> Member<'ast, T> for UExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { let members = s.ty().clone(); let ty = members - .into_iter() + .iter() .find(|member| *member.id == member_id) .unwrap() - .ty; + .ty + .clone(); let bitwidth = match *ty { Type::Uint(bitwidth) => bitwidth, @@ -1325,15 +1868,16 @@ impl<'ast, T> Member<'ast, T> for UExpression<'ast, T> { } } -impl<'ast, T> Member<'ast, T> for ArrayExpression<'ast, T> { +impl<'ast, T: Clone> Member<'ast, T> for ArrayExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { let members = s.ty().clone(); let ty = members - .into_iter() + .iter() .find(|member| *member.id == member_id) .unwrap() - .ty; + .ty + .clone(); let (ty, size) = match *ty { Type::Array(array_type) => (array_type.ty, array_type.size), @@ -1344,15 +1888,16 @@ impl<'ast, T> Member<'ast, T> for ArrayExpression<'ast, T> { } } -impl<'ast, T> Member<'ast, T> for StructExpression<'ast, T> { +impl<'ast, T: Clone> Member<'ast, T> for StructExpression<'ast, T> { fn member(s: StructExpression<'ast, T>, member_id: MemberId) -> Self { let members = s.ty().clone(); let ty = members - .into_iter() + .iter() .find(|member| *member.id == member_id) .unwrap() - .ty; + .ty + .clone(); let members = match *ty { Type::Struct(members) => members, @@ -1362,3 +1907,83 @@ impl<'ast, T> Member<'ast, T> for StructExpression<'ast, T> { StructExpressionInner::Member(box s, member_id).annotate(members) } } + +pub trait FunctionCall<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self; +} + +impl<'ast, T: Field> FunctionCall<'ast, T> for FieldElementExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self { + assert_eq!(output_type, Type::FieldElement); + FieldElementExpression::FunctionCall(key, generics, arguments) + } +} + +impl<'ast, T: Field> FunctionCall<'ast, T> for BooleanExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self { + assert_eq!(output_type, Type::Boolean); + BooleanExpression::FunctionCall(key, generics, arguments) + } +} + +impl<'ast, T: Field> FunctionCall<'ast, T> for UExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self { + let bitwidth = match output_type { + Type::Uint(bitwidth) => bitwidth, + _ => unreachable!(), + }; + UExpressionInner::FunctionCall(key, generics, arguments).annotate(bitwidth) + } +} + +impl<'ast, T: Field> FunctionCall<'ast, T> for ArrayExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self { + let array_ty = match output_type { + Type::Array(array_ty) => array_ty, + _ => unreachable!(), + }; + ArrayExpressionInner::FunctionCall(key, generics, arguments) + .annotate(*array_ty.ty, array_ty.size) + } +} + +impl<'ast, T: Field> FunctionCall<'ast, T> for StructExpression<'ast, T> { + fn function_call( + key: DeclarationFunctionKey<'ast>, + generics: Vec>>, + arguments: Vec>, + output_type: Type<'ast, T>, + ) -> Self { + let struct_ty = match output_type { + Type::Struct(struct_ty) => struct_ty, + _ => unreachable!(), + }; + + StructExpressionInner::FunctionCall(key, generics, arguments).annotate(struct_ty) + } +} diff --git a/zokrates_core/src/typed_absy/parameter.rs b/zokrates_core/src/typed_absy/parameter.rs index 277ac537..6a41c0e2 100644 --- a/zokrates_core/src/typed_absy/parameter.rs +++ b/zokrates_core/src/typed_absy/parameter.rs @@ -1,30 +1,33 @@ -use crate::typed_absy::Variable; +use crate::typed_absy::types::Constant; +use crate::typed_absy::GVariable; use std::fmt; -#[derive(Clone, PartialEq)] -pub struct Parameter<'ast> { - pub id: Variable<'ast>, +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct GParameter<'ast, S> { + pub id: GVariable<'ast, S>, pub private: bool, } -impl<'ast> Parameter<'ast> { - #[cfg(test)] - pub fn private(v: Variable<'ast>) -> Self { - Parameter { +#[cfg(test)] +impl<'ast, S> From> for GParameter<'ast, S> { + fn from(v: GVariable<'ast, S>) -> Self { + GParameter { id: v, private: true, } } } -impl<'ast> fmt::Display for Parameter<'ast> { +pub type DeclarationParameter<'ast> = GParameter<'ast, Constant<'ast>>; + +impl<'ast, S: fmt::Display + Clone> fmt::Display for GParameter<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let visibility = if self.private { "private " } else { "" }; write!(f, "{}{} {}", visibility, self.id.get_type(), self.id.id) } } -impl<'ast> fmt::Debug for Parameter<'ast> { +impl<'ast, S: fmt::Debug> fmt::Debug for GParameter<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Parameter(variable: {:?})", self.id) } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs new file mode 100644 index 00000000..16aa40db --- /dev/null +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -0,0 +1,831 @@ +// Generic walk through a typed AST. Not mutating in place + +use crate::typed_absy::types::{ArrayType, StructMember, StructType}; +use crate::typed_absy::*; +use zokrates_field::Field; + +pub trait ResultFolder<'ast, T: Field>: Sized { + type Error; + + fn fold_program( + &mut self, + p: TypedProgram<'ast, T>, + ) -> Result, Self::Error> { + fold_program(self, p) + } + + fn fold_module( + &mut self, + p: TypedModule<'ast, T>, + ) -> Result, Self::Error> { + fold_module(self, p) + } + + fn fold_function_symbol( + &mut self, + s: TypedFunctionSymbol<'ast, T>, + ) -> Result, Self::Error> { + fold_function_symbol(self, s) + } + + fn fold_function( + &mut self, + f: TypedFunction<'ast, T>, + ) -> Result, Self::Error> { + fold_function(self, f) + } + + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast>, + ) -> Result, Self::Error> { + Ok(DeclarationParameter { + id: self.fold_declaration_variable(p.id)?, + ..p + }) + } + + fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { + Ok(n) + } + + fn fold_variable(&mut self, v: Variable<'ast, T>) -> Result, Self::Error> { + Ok(Variable { + id: self.fold_name(v.id)?, + _type: self.fold_type(v._type)?, + }) + } + + fn fold_declaration_variable( + &mut self, + v: DeclarationVariable<'ast>, + ) -> Result, Self::Error> { + Ok(DeclarationVariable { + id: self.fold_name(v.id)?, + _type: self.fold_declaration_type(v._type)?, + }) + } + + fn fold_type(&mut self, t: Type<'ast, T>) -> Result, Self::Error> { + use self::GType::*; + + match t { + Array(array_type) => Ok(Array(self.fold_array_type(array_type)?)), + Struct(struct_type) => Ok(Struct(self.fold_struct_type(struct_type)?)), + t => Ok(t), + } + } + + fn fold_array_type( + &mut self, + t: ArrayType<'ast, T>, + ) -> Result, Self::Error> { + Ok(ArrayType { + ty: box self.fold_type(*t.ty)?, + size: self.fold_uint_expression(t.size)?, + }) + } + + fn fold_struct_type( + &mut self, + t: StructType<'ast, T>, + ) -> Result, Self::Error> { + Ok(StructType { + members: t + .members + .into_iter() + .map(|m| { + let id = m.id; + self.fold_type(*m.ty) + .map(|ty| StructMember { ty: box ty, id }) + }) + .collect::>()?, + ..t + }) + } + + fn fold_declaration_type( + &mut self, + t: DeclarationType<'ast>, + ) -> Result, Self::Error> { + Ok(t) + } + + fn fold_assignee( + &mut self, + a: TypedAssignee<'ast, T>, + ) -> Result, Self::Error> { + match a { + TypedAssignee::Identifier(v) => Ok(TypedAssignee::Identifier(self.fold_variable(v)?)), + TypedAssignee::Select(box a, box index) => Ok(TypedAssignee::Select( + box self.fold_assignee(a)?, + box self.fold_uint_expression(index)?, + )), + TypedAssignee::Member(box s, m) => { + Ok(TypedAssignee::Member(box self.fold_assignee(s)?, m)) + } + } + } + + fn fold_statement( + &mut self, + s: TypedStatement<'ast, T>, + ) -> Result>, Self::Error> { + fold_statement(self, s) + } + + fn fold_expression_or_spread( + &mut self, + e: TypedExpressionOrSpread<'ast, T>, + ) -> Result, Self::Error> { + Ok(match e { + TypedExpressionOrSpread::Expression(e) => { + TypedExpressionOrSpread::Expression(self.fold_expression(e)?) + } + TypedExpressionOrSpread::Spread(s) => { + TypedExpressionOrSpread::Spread(self.fold_spread(s)?) + } + }) + } + + fn fold_spread( + &mut self, + s: TypedSpread<'ast, T>, + ) -> Result, Self::Error> { + Ok(TypedSpread { + array: self.fold_array_expression(s.array)?, + }) + } + + fn fold_expression( + &mut self, + e: TypedExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + TypedExpression::FieldElement(e) => Ok(self.fold_field_expression(e)?.into()), + TypedExpression::Boolean(e) => Ok(self.fold_boolean_expression(e)?.into()), + TypedExpression::Uint(e) => Ok(self.fold_uint_expression(e)?.into()), + TypedExpression::Array(e) => Ok(self.fold_array_expression(e)?.into()), + TypedExpression::Struct(e) => Ok(self.fold_struct_expression(e)?.into()), + TypedExpression::Int(e) => Ok(self.fold_int_expression(e)?.into()), + } + } + + fn fold_array_expression( + &mut self, + e: ArrayExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression(self, e) + } + + fn fold_struct_expression( + &mut self, + e: StructExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression(self, e) + } + + fn fold_expression_list( + &mut self, + es: TypedExpressionList<'ast, T>, + ) -> Result, Self::Error> { + fold_expression_list(self, es) + } + + fn fold_int_expression( + &mut self, + e: IntExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_int_expression(self, e) + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_field_expression(self, e) + } + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_boolean_expression(self, e) + } + fn fold_uint_expression( + &mut self, + e: UExpression<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression(self, e) + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_uint_expression_inner(self, bitwidth, e) + } + + fn fold_array_expression_inner( + &mut self, + ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_array_expression_inner(self, ty, e) + } + fn fold_struct_expression_inner( + &mut self, + ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + fold_struct_expression_inner(self, ty, e) + } +} + +pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedStatement<'ast, T>, +) -> Result>, F::Error> { + let res = match s { + TypedStatement::Return(expressions) => TypedStatement::Return( + expressions + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + TypedStatement::Definition(a, e) => { + TypedStatement::Definition(f.fold_assignee(a)?, f.fold_expression(e)?) + } + TypedStatement::Declaration(v) => TypedStatement::Declaration(f.fold_variable(v)?), + TypedStatement::Assertion(e) => TypedStatement::Assertion(f.fold_boolean_expression(e)?), + TypedStatement::For(v, from, to, statements) => TypedStatement::For( + f.fold_variable(v)?, + f.fold_uint_expression(from)?, + f.fold_uint_expression(to)?, + statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ), + TypedStatement::MultipleDefinition(variables, elist) => TypedStatement::MultipleDefinition( + variables + .into_iter() + .map(|v| f.fold_assignee(v)) + .collect::>()?, + f.fold_expression_list(elist)?, + ), + s => s, + }; + Ok(vec![res]) +} + +pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?), + ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value( + exprs + .into_iter() + .map(|e| f.fold_expression_or_spread(e)) + .collect::>()?, + ), + ArrayExpressionInner::FunctionCall(id, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + ArrayExpressionInner::FunctionCall(id, generics, exps) + } + ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { + ArrayExpressionInner::IfElse( + box f.fold_boolean_expression(condition)?, + box f.fold_array_expression(consequence)?, + box f.fold_array_expression(alternative)?, + ) + } + ArrayExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + ArrayExpressionInner::Member(box s, id) + } + ArrayExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + ArrayExpressionInner::Select(box array, box index) + } + ArrayExpressionInner::Slice(box array, box from, box to) => { + let array = f.fold_array_expression(array)?; + let from = f.fold_uint_expression(from)?; + let to = f.fold_uint_expression(to)?; + ArrayExpressionInner::Slice(box array, box from, box to) + } + ArrayExpressionInner::Repeat(box e, box count) => { + let e = f.fold_expression(e)?; + let count = f.fold_uint_expression(count)?; + ArrayExpressionInner::Repeat(box e, box count) + } + }; + Ok(e) +} + +pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + StructExpressionInner::Identifier(id) => { + StructExpressionInner::Identifier(f.fold_name(id)?) + } + StructExpressionInner::Value(exprs) => StructExpressionInner::Value( + exprs + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?, + ), + StructExpressionInner::FunctionCall(id, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + StructExpressionInner::FunctionCall(id, generics, exps) + } + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + StructExpressionInner::IfElse( + box f.fold_boolean_expression(condition)?, + box f.fold_struct_expression(consequence)?, + box f.fold_struct_expression(alternative)?, + ) + } + StructExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + StructExpressionInner::Member(box s, id) + } + StructExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + StructExpressionInner::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: FieldElementExpression<'ast, T>, +) -> Result, F::Error> { + let e = match e { + FieldElementExpression::Number(n) => FieldElementExpression::Number(n), + FieldElementExpression::Identifier(id) => { + FieldElementExpression::Identifier(f.fold_name(id)?) + } + FieldElementExpression::Add(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Add(box e1, box e2) + } + FieldElementExpression::Sub(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Sub(box e1, box e2) + } + FieldElementExpression::Mult(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Mult(box e1, box e2) + } + FieldElementExpression::Div(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + FieldElementExpression::Div(box e1, box e2) + } + FieldElementExpression::Pow(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + FieldElementExpression::Pow(box e1, box e2) + } + FieldElementExpression::Neg(box e) => { + let e = f.fold_field_expression(e)?; + + FieldElementExpression::Neg(box e) + } + FieldElementExpression::Pos(box e) => { + let e = f.fold_field_expression(e)?; + + FieldElementExpression::Pos(box e) + } + FieldElementExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_field_expression(cons)?; + let alt = f.fold_field_expression(alt)?; + FieldElementExpression::IfElse(box cond, box cons, box alt) + } + FieldElementExpression::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + FieldElementExpression::FunctionCall(key, generics, exps) + } + FieldElementExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + FieldElementExpression::Member(box s, id) + } + FieldElementExpression::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + FieldElementExpression::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + _: &mut F, + _: IntExpression<'ast, T>, +) -> Result, F::Error> { + unreachable!() +} + +pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: BooleanExpression<'ast, T>, +) -> Result, F::Error> { + let e = match e { + BooleanExpression::Value(v) => BooleanExpression::Value(v), + BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?), + BooleanExpression::FieldEq(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldEq(box e1, box e2) + } + BooleanExpression::BoolEq(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::BoolEq(box e1, box e2) + } + BooleanExpression::ArrayEq(box e1, box e2) => { + let e1 = f.fold_array_expression(e1)?; + let e2 = f.fold_array_expression(e2)?; + BooleanExpression::ArrayEq(box e1, box e2) + } + BooleanExpression::StructEq(box e1, box e2) => { + let e1 = f.fold_struct_expression(e1)?; + let e2 = f.fold_struct_expression(e2)?; + BooleanExpression::StructEq(box e1, box e2) + } + BooleanExpression::UintEq(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintEq(box e1, box e2) + } + BooleanExpression::FieldLt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLt(box e1, box e2) + } + BooleanExpression::FieldLe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldLe(box e1, box e2) + } + BooleanExpression::FieldGt(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGt(box e1, box e2) + } + BooleanExpression::FieldGe(box e1, box e2) => { + let e1 = f.fold_field_expression(e1)?; + let e2 = f.fold_field_expression(e2)?; + BooleanExpression::FieldGe(box e1, box e2) + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLt(box e1, box e2) + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintLe(box e1, box e2) + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGt(box e1, box e2) + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1)?; + let e2 = f.fold_uint_expression(e2)?; + BooleanExpression::UintGe(box e1, box e2) + } + BooleanExpression::Or(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::Or(box e1, box e2) + } + BooleanExpression::And(box e1, box e2) => { + let e1 = f.fold_boolean_expression(e1)?; + let e2 = f.fold_boolean_expression(e2)?; + BooleanExpression::And(box e1, box e2) + } + BooleanExpression::Not(box e) => { + let e = f.fold_boolean_expression(e)?; + BooleanExpression::Not(box e) + } + BooleanExpression::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + BooleanExpression::FunctionCall(key, generics, exps) + } + BooleanExpression::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_boolean_expression(cons)?; + let alt = f.fold_boolean_expression(alt)?; + BooleanExpression::IfElse(box cond, box cons, box alt) + } + BooleanExpression::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + BooleanExpression::Member(box s, id) + } + BooleanExpression::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + BooleanExpression::Select(box array, box index) + } + }; + Ok(e) +} + +pub fn fold_uint_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: UExpression<'ast, T>, +) -> Result, F::Error> { + Ok(UExpression { + inner: f.fold_uint_expression_inner(e.bitwidth, e.inner)?, + ..e + }) +} + +pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + _: UBitwidth, + e: UExpressionInner<'ast, T>, +) -> Result, F::Error> { + let e = match e { + UExpressionInner::Value(v) => UExpressionInner::Value(v), + UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?), + UExpressionInner::Add(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Add(box left, box right) + } + UExpressionInner::Sub(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Sub(box left, box right) + } + UExpressionInner::FloorSub(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::FloorSub(box left, box right) + } + UExpressionInner::Mult(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Mult(box left, box right) + } + UExpressionInner::Div(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Div(box left, box right) + } + UExpressionInner::Rem(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Rem(box left, box right) + } + UExpressionInner::Xor(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Xor(box left, box right) + } + UExpressionInner::And(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::And(box left, box right) + } + UExpressionInner::Or(box left, box right) => { + let left = f.fold_uint_expression(left)?; + let right = f.fold_uint_expression(right)?; + + UExpressionInner::Or(box left, box right) + } + UExpressionInner::LeftShift(box e, box by) => { + let e = f.fold_uint_expression(e)?; + let by = f.fold_uint_expression(by)?; + + UExpressionInner::LeftShift(box e, box by) + } + UExpressionInner::RightShift(box e, box by) => { + let e = f.fold_uint_expression(e)?; + let by = f.fold_uint_expression(by)?; + + UExpressionInner::RightShift(box e, box by) + } + UExpressionInner::Not(box e) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::Not(box e) + } + UExpressionInner::Neg(box e) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::Neg(box e) + } + UExpressionInner::Pos(box e) => { + let e = f.fold_uint_expression(e)?; + + UExpressionInner::Pos(box e) + } + UExpressionInner::FunctionCall(key, generics, exps) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + let exps = exps + .into_iter() + .map(|e| f.fold_expression(e)) + .collect::>()?; + UExpressionInner::FunctionCall(key, generics, exps) + } + UExpressionInner::Select(box array, box index) => { + let array = f.fold_array_expression(array)?; + let index = f.fold_uint_expression(index)?; + UExpressionInner::Select(box array, box index) + } + UExpressionInner::IfElse(box cond, box cons, box alt) => { + let cond = f.fold_boolean_expression(cond)?; + let cons = f.fold_uint_expression(cons)?; + let alt = f.fold_uint_expression(alt)?; + UExpressionInner::IfElse(box cond, box cons, box alt) + } + UExpressionInner::Member(box s, id) => { + let s = f.fold_struct_expression(s)?; + UExpressionInner::Member(box s, id) + } + }; + Ok(e) +} + +pub fn fold_function<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + fun: TypedFunction<'ast, T>, +) -> Result, F::Error> { + Ok(TypedFunction { + arguments: fun + .arguments + .into_iter() + .map(|a| f.fold_parameter(a)) + .collect::>()?, + statements: fun + .statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + ..fun + }) +} + +pub fn fold_array_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: ArrayExpression<'ast, T>, +) -> Result, F::Error> { + let ty = f.fold_array_type(*e.ty)?; + + Ok(ArrayExpression { + inner: f.fold_array_expression_inner(&ty, e.inner)?, + ty: box ty, + }) +} + +pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + es: TypedExpressionList<'ast, T>, +) -> Result, F::Error> { + match es { + TypedExpressionList::FunctionCall(id, generics, arguments, types) => { + let generics = generics + .into_iter() + .map(|g| g.map(|g| f.fold_uint_expression(g)).transpose()) + .collect::>()?; + Ok(TypedExpressionList::FunctionCall( + id, + generics, + arguments + .into_iter() + .map(|a| f.fold_expression(a)) + .collect::>()?, + types + .into_iter() + .map(|t| f.fold_type(t)) + .collect::>()?, + )) + } + TypedExpressionList::EmbedCall(embed, generics, arguments, types) => { + Ok(TypedExpressionList::EmbedCall( + embed, + generics, + arguments + .into_iter() + .map(|a| f.fold_expression(a)) + .collect::>()?, + types + .into_iter() + .map(|t| f.fold_type(t)) + .collect::>()?, + )) + } + } +} + +pub fn fold_struct_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + e: StructExpression<'ast, T>, +) -> Result, F::Error> { + Ok(StructExpression { + inner: f.fold_struct_expression_inner(&e.ty, e.inner)?, + ..e + }) +} + +pub fn fold_function_symbol<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + s: TypedFunctionSymbol<'ast, T>, +) -> Result, F::Error> { + match s { + TypedFunctionSymbol::Here(fun) => Ok(TypedFunctionSymbol::Here(f.fold_function(fun)?)), + there => Ok(there), // by default, do not fold modules recursively + } +} + +pub fn fold_module<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + p: TypedModule<'ast, T>, +) -> Result, F::Error> { + Ok(TypedModule { + functions: p + .functions + .into_iter() + .map(|(key, fun)| f.fold_function_symbol(fun).map(|f| (key, f))) + .collect::>()?, + }) +} + +pub fn fold_program<'ast, T: Field, F: ResultFolder<'ast, T>>( + f: &mut F, + p: TypedProgram<'ast, T>, +) -> Result, F::Error> { + Ok(TypedProgram { + modules: p + .modules + .into_iter() + .map(|(module_id, module)| f.fold_module(module).map(|m| (module_id, m))) + .collect::>()?, + main: p.main, + }) +} diff --git a/zokrates_core/src/typed_absy/types.rs b/zokrates_core/src/typed_absy/types.rs index e1831096..6f6ad274 100644 --- a/zokrates_core/src/typed_absy/types.rs +++ b/zokrates_core/src/typed_absy/types.rs @@ -1,24 +1,257 @@ +use crate::typed_absy::{TryFrom, TryInto}; +use crate::typed_absy::{TypedModuleId, UExpression, UExpressionInner}; use serde::{de::Error, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; +use std::collections::BTreeMap; use std::fmt; +use std::hash::{Hash, Hasher}; use std::path::{Path, PathBuf}; -pub type Identifier<'ast> = &'ast str; +pub type GenericIdentifier<'ast> = &'ast str; + +#[derive(Debug)] +pub struct SpecializationError; + +#[derive(Debug, Clone)] +pub enum Constant<'ast> { + Generic(GenericIdentifier<'ast>), + Concrete(u32), +} + +// At this stage we want all constants to be equal +impl<'ast> PartialEq for Constant<'ast> { + fn eq(&self, _: &Self) -> bool { + true + } +} + +impl<'ast> PartialOrd for Constant<'ast> { + fn partial_cmp(&self, _: &Self) -> std::option::Option { + Some(std::cmp::Ordering::Equal) + } +} + +impl<'ast> Ord for Constant<'ast> { + fn cmp(&self, _: &Self) -> std::cmp::Ordering { + std::cmp::Ordering::Equal + } +} + +impl<'ast> Eq for Constant<'ast> {} + +impl<'ast> Hash for Constant<'ast> { + fn hash(&self, _: &mut H) + where + H: Hasher, + { + // we do not hash anything, as we want all constant to hash to the same thing + } +} + +impl<'ast> From for Constant<'ast> { + fn from(e: u32) -> Self { + Constant::Concrete(e) + } +} + +impl<'ast> From for Constant<'ast> { + fn from(e: usize) -> Self { + Constant::Concrete(e as u32) + } +} + +impl<'ast> From> for Constant<'ast> { + fn from(e: GenericIdentifier<'ast>) -> Self { + Constant::Generic(e) + } +} + +impl<'ast> fmt::Display for Constant<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Constant::Generic(i) => write!(f, "{}", i), + Constant::Concrete(v) => write!(f, "{}", v), + } + } +} + +impl<'ast, T> From for UExpression<'ast, T> { + fn from(i: usize) -> Self { + UExpressionInner::Value(i as u128).annotate(UBitwidth::B32) + } +} + +impl<'ast, T> From> for UExpression<'ast, T> { + fn from(c: Constant<'ast>) -> Self { + match c { + Constant::Generic(i) => UExpressionInner::Identifier(i.into()).annotate(UBitwidth::B32), + Constant::Concrete(v) => UExpressionInner::Value(v as u128).annotate(UBitwidth::B32), + } + } +} + +impl<'ast, T> TryInto for UExpression<'ast, T> { + type Error = SpecializationError; + + fn try_into(self) -> Result { + assert_eq!(self.bitwidth, UBitwidth::B32); + + match self.into_inner() { + UExpressionInner::Value(v) => Ok(v as usize), + _ => Err(SpecializationError), + } + } +} + +impl<'ast> TryInto for Constant<'ast> { + type Error = SpecializationError; + + fn try_into(self) -> Result { + match self { + Constant::Concrete(v) => Ok(v as usize), + _ => Err(SpecializationError), + } + } +} pub type MemberId = String; #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] -pub struct StructMember { +pub struct GStructMember { #[serde(rename = "name")] pub id: MemberId, #[serde(flatten)] - pub ty: Box, + pub ty: Box>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] -pub struct ArrayType { - pub size: usize, +pub type DeclarationStructMember<'ast> = GStructMember>; +pub type ConcreteStructMember = GStructMember; +pub type StructMember<'ast, T> = GStructMember>; + +impl<'ast, T: PartialEq> PartialEq> for StructMember<'ast, T> { + fn eq(&self, other: &DeclarationStructMember<'ast>) -> bool { + self.id == other.id && *self.ty == *other.ty + } +} + +fn try_from_g_struct_member, U>( + t: GStructMember, +) -> Result, SpecializationError> { + Ok(GStructMember { + id: t.id, + ty: box try_from_g_type(*t.ty)?, + }) +} + +impl<'ast, T> TryFrom> for ConcreteStructMember { + type Error = SpecializationError; + + fn try_from(t: StructMember<'ast, T>) -> Result { + try_from_g_struct_member(t) + } +} + +impl<'ast, T> From for StructMember<'ast, T> { + fn from(t: ConcreteStructMember) -> Self { + try_from_g_struct_member(t).unwrap() + } +} + +impl<'ast> From for DeclarationStructMember<'ast> { + fn from(t: ConcreteStructMember) -> Self { + try_from_g_struct_member(t).unwrap() + } +} + +impl<'ast, T> From> for StructMember<'ast, T> { + fn from(t: DeclarationStructMember<'ast>) -> Self { + try_from_g_struct_member(t).unwrap() + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +pub struct GArrayType { + pub size: S, #[serde(flatten)] - pub ty: Box, + pub ty: Box>, +} + +pub type DeclarationArrayType<'ast> = GArrayType>; +pub type ConcreteArrayType = GArrayType; +pub type ArrayType<'ast, T> = GArrayType>; + +impl<'ast, T: PartialEq> PartialEq> for ArrayType<'ast, T> { + fn eq(&self, other: &DeclarationArrayType<'ast>) -> bool { + *self.ty == *other.ty + && match (self.size.as_inner(), &other.size) { + (UExpressionInner::Value(l), Constant::Concrete(r)) => *l as u32 == *r, + _ => true, + } + } +} + +impl fmt::Display for GArrayType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt_aux<'a, S: fmt::Display>( + f: &mut fmt::Formatter, + t: &'a GArrayType, + mut acc: Vec<&'a S>, + ) -> fmt::Result { + acc.push(&t.size); + match &*t.ty { + GType::Array(array_type) => fmt_aux(f, &array_type, acc), + t => { + write!(f, "{}", t)?; + for i in acc { + write!(f, "[{}]", i)?; + } + write!(f, "") + } + } + } + + let acc = vec![]; + + fmt_aux(f, &self, acc) + } +} + +impl<'ast, T: PartialEq + fmt::Display> Type<'ast, T> { + // array type equality with non-strict size checks + // sizes always match unless they are different constants + pub fn weak_eq(&self, other: &Self) -> bool { + match (self, other) { + (Type::Array(t), Type::Array(u)) => t.ty.weak_eq(&u.ty), + (Type::Struct(t), Type::Struct(u)) => t + .members + .iter() + .zip(u.members.iter()) + .all(|(m, n)| m.ty.weak_eq(&n.ty)), + (t, u) => t == u, + } + } +} + +fn try_from_g_array_type, U>( + t: GArrayType, +) -> Result, SpecializationError> { + Ok(GArrayType { + size: t.size.try_into().map_err(|_| SpecializationError)?, + ty: box try_from_g_type(*t.ty)?, + }) +} + +impl<'ast, T> TryFrom> for ConcreteArrayType { + type Error = SpecializationError; + + fn try_from(t: ArrayType<'ast, T>) -> Result { + try_from_g_array_type(t) + } +} + +impl<'ast, T> From for ArrayType<'ast, T> { + fn from(t: ConcreteArrayType) -> Self { + try_from_g_array_type(t).unwrap() + } } #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialOrd, Ord, Eq, PartialEq)] @@ -28,37 +261,99 @@ pub struct StructLocation { pub name: String, } -#[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialOrd, Ord)] -pub struct StructType { +impl<'ast> From for DeclarationArrayType<'ast> { + fn from(t: ConcreteArrayType) -> Self { + try_from_g_array_type(t).unwrap() + } +} + +impl<'ast, T> From> for ArrayType<'ast, T> { + fn from(t: DeclarationArrayType<'ast>) -> Self { + try_from_g_array_type(t).unwrap() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialOrd, Ord)] +pub struct GStructType { #[serde(flatten)] pub canonical_location: StructLocation, #[serde(skip)] pub location: Option, - pub members: Vec, + pub members: Vec>, } -impl PartialEq for StructType { +pub type DeclarationStructType<'ast> = GStructType>; +pub type ConcreteStructType = GStructType; +pub type StructType<'ast, T> = GStructType>; + +impl PartialEq for GStructType { fn eq(&self, other: &Self) -> bool { - self.canonical_location.eq(&other.canonical_location) && self.members.eq(&other.members) + self.canonical_location.eq(&other.canonical_location) } } -impl Eq for StructType {} +impl Hash for GStructType { + fn hash(&self, state: &mut H) { + self.canonical_location.hash(state); + } +} -impl StructType { - pub fn new(module: PathBuf, name: String, members: Vec) -> Self { - StructType { +impl Eq for GStructType {} + +fn try_from_g_struct_type, U>( + t: GStructType, +) -> Result, SpecializationError> { + Ok(GStructType { + location: t.location, + canonical_location: t.canonical_location, + members: t + .members + .into_iter() + .map(try_from_g_struct_member) + .collect::>()?, + }) +} + +impl<'ast, T> TryFrom> for ConcreteStructType { + type Error = SpecializationError; + + fn try_from(t: StructType<'ast, T>) -> Result { + try_from_g_struct_type(t) + } +} + +impl<'ast, T> From for StructType<'ast, T> { + fn from(t: ConcreteStructType) -> Self { + try_from_g_struct_type(t).unwrap() + } +} + +impl<'ast> From for DeclarationStructType<'ast> { + fn from(t: ConcreteStructType) -> Self { + try_from_g_struct_type(t).unwrap() + } +} + +impl<'ast, T> From> for StructType<'ast, T> { + fn from(t: DeclarationStructType<'ast>) -> Self { + try_from_g_struct_type(t).unwrap() + } +} + +impl GStructType { + pub fn new(module: PathBuf, name: String, members: Vec>) -> Self { + GStructType { canonical_location: StructLocation { module, name }, location: None, members, } } - pub fn len(&self) -> usize { + pub fn members_count(&self) -> usize { self.members.len() } - pub fn iter(&self) -> std::slice::Iter { + pub fn iter(&self) -> std::slice::Iter> { self.members.iter() } @@ -75,8 +370,8 @@ impl StructType { } } -impl IntoIterator for StructType { - type Item = StructMember; +impl IntoIterator for GStructType { + type Item = GStructMember; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { @@ -118,144 +413,221 @@ impl fmt::Display for UBitwidth { } #[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum Type { +pub enum GType { FieldElement, Boolean, - Array(ArrayType), - Struct(StructType), + Array(GArrayType), + Struct(GStructType), Uint(UBitwidth), + Int, } -impl serde::Serialize for Type { +impl Serialize for GType { fn serialize(&self, s: S) -> Result<::Ok, ::Error> where S: Serializer, { + use serde::ser::Error; + match self { - Type::FieldElement => s.serialize_newtype_variant("Type", 0, "type", "field"), - Type::Boolean => s.serialize_newtype_variant("Type", 1, "type", "bool"), - Type::Array(array_type) => { + GType::FieldElement => s.serialize_newtype_variant("Type", 0, "type", "field"), + GType::Boolean => s.serialize_newtype_variant("Type", 1, "type", "bool"), + GType::Array(array_type) => { let mut map = s.serialize_map(Some(2))?; map.serialize_entry("type", "array")?; map.serialize_entry("components", array_type)?; map.end() } - Type::Struct(struct_type) => { + GType::Struct(struct_type) => { let mut map = s.serialize_map(Some(2))?; map.serialize_entry("type", "struct")?; map.serialize_entry("components", struct_type)?; map.end() } - Type::Uint(width) => s.serialize_newtype_variant( + GType::Uint(width) => s.serialize_newtype_variant( "Type", 4, "type", format!("u{}", width.to_usize()).as_str(), ), + GType::Int => Err(S::Error::custom( + "Cannot serialize Int type as it's not allowed in function signatures".to_string(), + )), } } } -impl<'de> serde::Deserialize<'de> for Type { +impl<'de, S: Deserialize<'de>> Deserialize<'de> for GType { fn deserialize(d: D) -> Result>::Error> where D: Deserializer<'de>, { - #[derive(Debug, Deserialize)] + #[derive(Deserialize)] #[serde(untagged)] - enum Components { - Array(ArrayType), - Struct(StructType), + enum Components { + Array(GArrayType), + Struct(GStructType), } - #[derive(Debug, Deserialize)] - struct Mapping { + #[derive(Deserialize)] + struct Mapping { #[serde(rename = "type")] ty: String, - components: Option, + components: Option>, } - let strict_type = |m: Mapping, ty: Type| -> Result>::Error> { - match m.components { - Some(_) => Err(D::Error::custom(format!( - "unexpected `components` field for type `{}`", - ty - ))), - None => Ok(ty), - } - }; + let strict_type = + |m: Mapping, ty: GType| -> Result>::Error> { + match m.components { + Some(_) => Err(D::Error::custom(format!( + "unexpected `components` field in type {}", + m.ty + ))), + None => Ok(ty), + } + }; let mapping = Mapping::deserialize(d)?; match mapping.ty.as_str() { - "field" => strict_type(mapping, Type::FieldElement), - "bool" => strict_type(mapping, Type::Boolean), + "field" => strict_type(mapping, GType::FieldElement), + "bool" => strict_type(mapping, GType::Boolean), "array" => { - let components = mapping.components.ok_or(D::Error::custom(format_args!( - "missing `components` field for type `{}'", - mapping.ty - )))?; + let components = mapping + .components + .ok_or_else(|| D::Error::custom("missing `components` field".to_string()))?; match components { - Components::Array(array_type) => Ok(Type::Array(array_type)), - _ => Err(D::Error::custom(format!( - "invalid `components` variant for type `{}`", - mapping.ty - ))), + Components::Array(array_type) => Ok(GType::Array(array_type)), + _ => Err(D::Error::custom("invalid `components` variant".to_string())), } } "struct" => { - let components = mapping.components.ok_or(D::Error::custom(format_args!( - "missing `components` field for type `{}'", - mapping.ty - )))?; + let components = mapping + .components + .ok_or_else(|| D::Error::custom("missing `components` field".to_string()))?; match components { - Components::Struct(struct_type) => Ok(Type::Struct(struct_type)), - _ => Err(D::Error::custom(format!( - "invalid `components` variant for type `{}`", - mapping.ty - ))), + Components::Struct(struct_type) => Ok(GType::Struct(struct_type)), + _ => Err(D::Error::custom("invalid `components` variant".to_string())), } } - "u8" => strict_type(mapping, Type::Uint(UBitwidth::B8)), - "u16" => strict_type(mapping, Type::Uint(UBitwidth::B16)), - "u32" => strict_type(mapping, Type::Uint(UBitwidth::B32)), - _ => Err(D::Error::custom(format!("invalid type `{}`", mapping.ty))), + "u8" => strict_type(mapping, GType::Uint(UBitwidth::B8)), + "u16" => strict_type(mapping, GType::Uint(UBitwidth::B16)), + "u32" => strict_type(mapping, GType::Uint(UBitwidth::B32)), + t => Err(D::Error::custom(format!("invalid type `{}`", t))), } } } -impl ArrayType { - pub fn new(ty: Type, size: usize) -> Self { - ArrayType { +pub type DeclarationType<'ast> = GType>; +pub type ConcreteType = GType; +pub type Type<'ast, T> = GType>; + +impl<'ast, T: PartialEq> PartialEq> for Type<'ast, T> { + fn eq(&self, other: &DeclarationType<'ast>) -> bool { + use self::GType::*; + + match (self, other) { + (Array(l), Array(r)) => l == r, + (Struct(l), Struct(r)) => l.canonical_location == r.canonical_location, + (FieldElement, FieldElement) | (Boolean, Boolean) => true, + (Uint(l), Uint(r)) => l == r, + _ => false, + } + } +} + +fn try_from_g_type, U>(t: GType) -> Result, SpecializationError> { + match t { + GType::FieldElement => Ok(GType::FieldElement), + GType::Boolean => Ok(GType::Boolean), + GType::Int => Ok(GType::Int), + GType::Uint(bitwidth) => Ok(GType::Uint(bitwidth)), + GType::Array(array_type) => Ok(GType::Array(try_from_g_array_type(array_type)?)), + GType::Struct(struct_type) => Ok(GType::Struct(try_from_g_struct_type(struct_type)?)), + } +} + +impl<'ast, T> TryFrom> for ConcreteType { + type Error = SpecializationError; + + fn try_from(t: Type<'ast, T>) -> Result { + try_from_g_type(t) + } +} + +impl<'ast, T> From for Type<'ast, T> { + fn from(t: ConcreteType) -> Self { + try_from_g_type(t).unwrap() + } +} + +impl<'ast> From for DeclarationType<'ast> { + fn from(t: ConcreteType) -> Self { + try_from_g_type(t).unwrap() + } +} + +impl<'ast, T> From> for Type<'ast, T> { + fn from(t: DeclarationType<'ast>) -> Self { + try_from_g_type(t).unwrap() + } +} + +impl> From<(GType, U)> for GArrayType { + fn from(tup: (GType, U)) -> Self { + GArrayType { + ty: box tup.0, + size: tup.1.into(), + } + } +} + +impl GArrayType { + pub fn new(ty: GType, size: S) -> Self { + GArrayType { ty: Box::new(ty), size, } } } -impl StructMember { - pub fn new(id: String, ty: Type) -> Self { - StructMember { +impl GStructMember { + pub fn new(id: String, ty: GType) -> Self { + GStructMember { id, ty: Box::new(ty), } } } -impl fmt::Display for Type { +impl fmt::Display for GType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Type::FieldElement => write!(f, "field"), - Type::Boolean => write!(f, "bool"), - Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), - Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref struct_type) => write!( + GType::FieldElement => write!(f, "field"), + GType::Boolean => write!(f, "bool"), + GType::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), + GType::Int => write!(f, "{{integer}}"), + GType::Array(ref array_type) => write!(f, "{}", array_type), + GType::Struct(ref struct_type) => write!(f, "{}", struct_type.name(),), + } + } +} + +impl fmt::Debug for GType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + GType::FieldElement => write!(f, "field"), + GType::Boolean => write!(f, "bool"), + GType::Int => write!(f, "integer"), + GType::Uint(ref bitwidth) => write!(f, "u{:?}", bitwidth), + GType::Array(ref array_type) => write!(f, "{:?}[{:?}]", array_type.ty, array_type.size), + GType::Struct(ref struct_type) => write!( f, - "{} {{{}}}", + "{:?} {{{:?}}}", struct_type.name(), struct_type .members .iter() - .map(|member| format!("{}: {}", member.id, member.ty)) + .map(|member| format!("{:?}: {:?}", member.id, member.ty)) .collect::>() .join(", ") ), @@ -263,38 +635,48 @@ impl fmt::Display for Type { } } -impl fmt::Debug for Type { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Type::FieldElement => write!(f, "field"), - Type::Boolean => write!(f, "bool"), - Type::Uint(ref bitwidth) => write!(f, "u{}", bitwidth), - Type::Array(ref array_type) => write!(f, "{}[{}]", array_type.ty, array_type.size), - Type::Struct(ref struct_type) => write!(f, "{:?}", struct_type), +impl GType { + pub fn array>>(array_ty: U) -> Self { + GType::Array(array_ty.into()) + } + + pub fn struc>>(struct_ty: U) -> Self { + GType::Struct(struct_ty.into()) + } + + pub fn uint>(b: W) -> Self { + GType::Uint(b.into()) + } +} + +impl<'ast, T: fmt::Display + PartialEq + fmt::Debug> Type<'ast, T> { + pub fn can_be_specialized_to(&self, other: &DeclarationType) -> bool { + use self::GType::*; + + if self == other { + true + } else { + match (self, other) { + (Int, FieldElement) | (Int, Uint(..)) => true, + (Array(l), Array(r)) => l.ty.can_be_specialized_to(&r.ty), + // types do not come into play for Struct equality, only the canonical location. Hence no inference + // can change anything + (Struct(_), Struct(_)) => false, + _ => false, + } } } } -impl Type { - pub fn array(ty: Type, size: usize) -> Self { - Type::Array(ArrayType::new(ty, size)) - } - - pub fn struc(struct_ty: StructType) -> Self { - Type::Struct(struct_ty) - } - - pub fn uint>(b: W) -> Self { - Type::Uint(b.into()) - } - +impl ConcreteType { fn to_slug(&self) -> String { match self { - Type::FieldElement => String::from("f"), - Type::Boolean => String::from("b"), - Type::Uint(bitwidth) => format!("u{}", bitwidth), - Type::Array(array_type) => format!("{}[{}]", array_type.ty.to_slug(), array_type.size), - Type::Struct(struct_type) => format!( + GType::FieldElement => String::from("f"), + GType::Int => unreachable!(), + GType::Boolean => String::from("b"), + GType::Uint(bitwidth) => format!("u{}", bitwidth), + GType::Array(array_type) => format!("{}[{}]", array_type.ty.to_slug(), array_type.size), + GType::Struct(struct_type) => format!( "{{{}}}", struct_type .iter() @@ -304,15 +686,18 @@ impl Type { ), } } +} +impl ConcreteType { // the number of field elements the type maps to pub fn get_primitive_count(&self) -> usize { match self { - Type::FieldElement => 1, - Type::Boolean => 1, - Type::Uint(_) => 1, - Type::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), - Type::Struct(struct_type) => struct_type + GType::FieldElement => 1, + GType::Boolean => 1, + GType::Uint(_) => 1, + GType::Array(array_type) => array_type.size * array_type.ty.get_primitive_count(), + GType::Int => unreachable!(), + GType::Struct(struct_type) => struct_type .iter() .map(|member| member.ty.get_primitive_count()) .sum(), @@ -323,71 +708,418 @@ impl Type { pub type FunctionIdentifier<'ast> = &'ast str; #[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub struct FunctionKey<'ast> { +pub struct GFunctionKey<'ast, S> { + pub module: TypedModuleId, pub id: FunctionIdentifier<'ast>, - pub signature: Signature, + pub signature: GSignature, } -pub type FunctionKeyHash = u64; +pub type DeclarationFunctionKey<'ast> = GFunctionKey<'ast, Constant<'ast>>; +pub type ConcreteFunctionKey<'ast> = GFunctionKey<'ast, usize>; +pub type FunctionKey<'ast, T> = GFunctionKey<'ast, UExpression<'ast, T>>; -impl<'ast> FunctionKey<'ast> { - pub fn with_id>>(id: S) -> Self { - FunctionKey { +impl<'ast, S: fmt::Display> fmt::Display for GFunctionKey<'ast, S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}/{}{}", self.module.display(), self.id, self.signature) + } +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct GGenericsAssignment<'ast, S>(pub BTreeMap, S>); + +pub type ConcreteGenericsAssignment<'ast> = GGenericsAssignment<'ast, usize>; +pub type GenericsAssignment<'ast, T> = GGenericsAssignment<'ast, UExpression<'ast, T>>; + +impl<'ast, S> Default for GGenericsAssignment<'ast, S> { + fn default() -> Self { + GGenericsAssignment(BTreeMap::new()) + } +} + +impl<'ast, S: fmt::Display> fmt::Display for GGenericsAssignment<'ast, S> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + self.0 + .iter() + .map(|(k, v)| format!("{}: {}", k, v)) + .collect::>() + .join(", ") + ) + } +} + +impl<'ast> PartialEq> for ConcreteFunctionKey<'ast> { + fn eq(&self, other: &DeclarationFunctionKey<'ast>) -> bool { + self.module == other.module && self.id == other.id && self.signature == other.signature + } +} + +fn try_from_g_function_key, U>( + k: GFunctionKey, +) -> Result, SpecializationError> { + Ok(GFunctionKey { + module: k.module, + signature: signature::try_from_g_signature(k.signature)?, + id: k.id, + }) +} + +impl<'ast, T> TryFrom> for ConcreteFunctionKey<'ast> { + type Error = SpecializationError; + + fn try_from(k: FunctionKey<'ast, T>) -> Result { + try_from_g_function_key(k) + } +} + +// impl<'ast> TryFrom> for ConcreteFunctionKey<'ast> { +// type Error = SpecializationError; + +// fn try_from(k: DeclarationFunctionKey<'ast>) -> Result { +// try_from_g_function_key(k) +// } +// } + +impl<'ast, T> From> for FunctionKey<'ast, T> { + fn from(k: ConcreteFunctionKey<'ast>) -> Self { + try_from_g_function_key(k).unwrap() + } +} + +impl<'ast> From> for DeclarationFunctionKey<'ast> { + fn from(k: ConcreteFunctionKey<'ast>) -> Self { + try_from_g_function_key(k).unwrap() + } +} + +impl<'ast, T> From> for FunctionKey<'ast, T> { + fn from(k: DeclarationFunctionKey<'ast>) -> Self { + try_from_g_function_key(k).unwrap() + } +} + +impl<'ast, S> GFunctionKey<'ast, S> { + pub fn with_location, U: Into>>( + module: T, + id: U, + ) -> Self { + GFunctionKey { + module: module.into(), id: id.into(), - signature: Signature::new(), + signature: GSignature::new(), } } - pub fn hash(&self) -> FunctionKeyHash { - use std::hash::Hasher; - - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - ::hash(self, &mut hasher); - hasher.finish() - } - - pub fn signature(mut self, signature: Signature) -> Self { + pub fn signature(mut self, signature: GSignature) -> Self { self.signature = signature; self } - pub fn id>>(mut self, id: S) -> Self { + pub fn id>>(mut self, id: U) -> Self { self.id = id.into(); self } - pub fn to_slug(&self) -> String { - format!("{}_{}", self.id, self.signature.to_slug()) + pub fn module>(mut self, module: T) -> Self { + self.module = module.into(); + self } } -pub use self::signature::Signature; -// use serde::de::Error; -// use serde::ser::SerializeMap; -// use serde::{Deserializer, Serializer}; +impl<'ast> ConcreteFunctionKey<'ast> { + pub fn to_slug(&self) -> String { + format!( + "{}/{}_{}", + self.module.display(), + self.id, + self.signature.to_slug() + ) + } +} + +pub use self::signature::{ConcreteSignature, DeclarationSignature, GSignature, Signature}; pub mod signature { use super::*; use std::fmt; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)] - pub struct Signature { - pub inputs: Vec, - pub outputs: Vec, + #[derive(Clone, Serialize, Deserialize, Eq)] + pub struct GSignature { + pub generics: Vec>, + pub inputs: Vec>, + pub outputs: Vec>, } - impl fmt::Debug for Signature { + impl PartialOrd for GSignature { + fn partial_cmp(&self, other: &Self) -> std::option::Option { + match self.inputs.partial_cmp(&other.inputs) { + Some(std::cmp::Ordering::Equal) => self.outputs.partial_cmp(&other.outputs), + r => r, + } + } + } + + impl Ord for GSignature { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(&other).unwrap() + } + } + + impl Hash for GSignature { + fn hash(&self, state: &mut H) { + self.inputs.hash(state); + self.outputs.hash(state); + } + } + + impl PartialEq for GSignature { + fn eq(&self, other: &GSignature) -> bool { + // we ignore generics as we want a generic function to conflict with its specialized (generics free) version + self.inputs == other.inputs && self.outputs == other.outputs + } + } + + impl Default for GSignature { + fn default() -> Self { + GSignature { + generics: vec![], + inputs: vec![], + outputs: vec![], + } + } + } + + pub type DeclarationSignature<'ast> = GSignature>; + pub type ConcreteSignature = GSignature; + pub type Signature<'ast, T> = GSignature>; + + use std::collections::btree_map::Entry; + + fn check_type<'ast, S: Clone + PartialEq + PartialEq>( + decl_ty: &DeclarationType<'ast>, + ty: >ype, + constants: &mut GGenericsAssignment<'ast, S>, + ) -> bool { + match (decl_ty, ty) { + (DeclarationType::Array(t0), GType::Array(t1)) => { + let s1 = t1.size.clone(); + + // both the inner type and the size must match + check_type(&t0.ty, &t1.ty, constants) + && match t0.size { + // if the declared size is an identifier, we insert into the map, or check if the concrete size + // matches if this identifier is already in the map + Constant::Generic(id) => match constants.0.entry(id) { + Entry::Occupied(e) => *e.get() == s1, + Entry::Vacant(e) => { + e.insert(s1); + true + } + }, + Constant::Concrete(s0) => s1 == s0 as usize, + } + } + (DeclarationType::FieldElement, GType::FieldElement) + | (DeclarationType::Boolean, GType::Boolean) => true, + (DeclarationType::Uint(b0), GType::Uint(b1)) => b0 == b1, + (DeclarationType::Struct(s0), GType::Struct(s1)) => { + s0.canonical_location == s1.canonical_location + } + _ => false, + } + } + + fn specialize_type<'ast, S: Clone + PartialEq + PartialEq + From + fmt::Debug>( + decl_ty: DeclarationType<'ast>, + constants: &GGenericsAssignment<'ast, S>, + ) -> Result, GenericIdentifier<'ast>> { + Ok(match decl_ty { + DeclarationType::Int => unreachable!(), + DeclarationType::Array(t0) => { + // let s1 = t1.size.clone(); + + let ty = box specialize_type(*t0.ty, &constants)?; + let size = match t0.size { + Constant::Generic(s) => constants.0.get(&s).cloned().ok_or(s), + Constant::Concrete(s) => Ok(s.into()), + }?; + + GType::Array(GArrayType { ty, size }) + } + DeclarationType::FieldElement => GType::FieldElement, + DeclarationType::Boolean => GType::Boolean, + DeclarationType::Uint(b0) => GType::Uint(b0), + DeclarationType::Struct(s0) => GType::Struct(GStructType { + members: s0 + .members + .into_iter() + .map(|m| { + let id = m.id; + specialize_type(*m.ty, constants).map(|ty| GStructMember { ty: box ty, id }) + }) + .collect::>()?, + canonical_location: s0.canonical_location, + location: s0.location, + }), + }) + } + + impl<'ast> PartialEq> for ConcreteSignature { + fn eq(&self, other: &DeclarationSignature<'ast>) -> bool { + // we keep track of the value of constants in a map, as a given constant can only have one value + let mut constants = ConcreteGenericsAssignment::default(); + + other + .inputs + .iter() + .chain(other.outputs.iter()) + .zip(self.inputs.iter().chain(self.outputs.iter())) + .all(|(decl_ty, ty)| check_type::(decl_ty, ty, &mut constants)) + } + } + + impl<'ast> DeclarationSignature<'ast> { + pub fn specialize( + &self, + values: Vec>, + signature: &ConcreteSignature, + ) -> Result, SpecializationError> { + // we keep track of the value of constants in a map, as a given constant can only have one value + let mut constants = ConcreteGenericsAssignment::default(); + + assert_eq!(self.generics.len(), values.len()); + + let decl_generics = self.generics.iter().map(|g| match g.clone().unwrap() { + Constant::Generic(g) => g, + _ => unreachable!(), + }); + + constants.0.extend( + decl_generics + .zip(values.into_iter()) + .filter_map(|(g, v)| v.map(|v| (g, v as usize))), + ); + + let condition = self + .inputs + .iter() + .chain(self.outputs.iter()) + .zip(signature.inputs.iter().chain(signature.outputs.iter())) + .all(|(decl_ty, ty)| check_type(decl_ty, ty, &mut constants)); + + if constants.0.len() != self.generics.len() { + return Err(SpecializationError); + } + + match condition { + true => Ok(constants), + false => Err(SpecializationError), + } + } + + pub fn get_output_types( + &self, + inputs: Vec>, + ) -> Result>, GenericIdentifier<'ast>> { + // we keep track of the value of constants in a map, as a given constant can only have one value + let mut constants = GenericsAssignment::default(); + + // fill the map with the inputs + let _ = self + .inputs + .iter() + .zip(inputs.iter()) + .all(|(decl_ty, ty)| check_type(decl_ty, ty, &mut constants)); + + // get the outputs from the map + self.outputs + .clone() + .into_iter() + .map(|t| specialize_type(t, &constants)) + .collect::>() + } + } + + pub fn try_from_g_signature, U>( + t: GSignature, + ) -> Result, SpecializationError> { + Ok(GSignature { + generics: t + .generics + .into_iter() + .map(|g| match g { + Some(g) => g.try_into().map(Some).map_err(|_| SpecializationError), + None => Ok(None), + }) + .collect::>()?, + inputs: t + .inputs + .into_iter() + .map(try_from_g_type) + .collect::>()?, + outputs: t + .outputs + .into_iter() + .map(try_from_g_type) + .collect::>()?, + }) + } + + impl<'ast, T> TryFrom> for ConcreteSignature { + type Error = SpecializationError; + + fn try_from(s: Signature<'ast, T>) -> Result { + try_from_g_signature(s) + } + } + + impl<'ast, T> From for Signature<'ast, T> { + fn from(s: ConcreteSignature) -> Self { + try_from_g_signature(s).unwrap() + } + } + + impl<'ast> From for DeclarationSignature<'ast> { + fn from(s: ConcreteSignature) -> Self { + try_from_g_signature(s).unwrap() + } + } + + impl<'ast, T> From> for Signature<'ast, T> { + fn from(s: DeclarationSignature<'ast>) -> Self { + try_from_g_signature(s).unwrap() + } + } + + impl fmt::Debug for GSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "Signature(inputs: {:?}, outputs: {:?})", - self.inputs, self.outputs + "Signature(generics: {:?}, inputs: {:?}, outputs: {:?})", + self.generics, self.inputs, self.outputs ) } } - impl fmt::Display for Signature { + impl fmt::Display for GSignature { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if !self.generics.is_empty() { + write!( + f, + "<{}>", + self.generics + .iter() + .map(|g| g + .as_ref() + .map(|g| g.to_string()) + .unwrap_or_else(|| '_'.to_string())) + .collect::>() + .join(", ") + )?; + } + write!(f, "(")?; for (i, t) in self.inputs.iter().enumerate() { write!(f, "{}", t)?; @@ -413,7 +1145,28 @@ pub mod signature { } } - impl Signature { + impl GSignature { + pub fn new() -> GSignature { + Self::default() + } + + pub fn generics(mut self, generics: Vec>) -> Self { + self.generics = generics; + self + } + + pub fn inputs(mut self, inputs: Vec>) -> Self { + self.inputs = inputs; + self + } + + pub fn outputs(mut self, outputs: Vec>) -> Self { + self.outputs = outputs; + self + } + } + + impl ConcreteSignature { /// Returns a slug for a signature, with the following encoding: /// i{inputs}o{outputs} where {inputs} and {outputs} each encode a list of types. /// A list of types is encoded by compressing sequences of the same type like so: @@ -424,22 +1177,20 @@ pub mod signature { /// [field, field, bool, field] -> 2fbf /// pub fn to_slug(&self) -> String { - let to_slug = |types| { + let to_slug = |types: &[ConcreteType]| { let mut res = vec![]; for t in types { let len = res.len(); if len == 0 { res.push((1, t)) + } else if res[len - 1].1 == t { + res[len - 1].0 += 1; } else { - if res[len - 1].1 == t { - res[len - 1].0 += 1; - } else { - res.push((1, t)) - } + res.push((1, t)) } } res.into_iter() - .map(|(n, t): (usize, &Type)| { + .map(|(n, t): (usize, &ConcreteType)| { let mut r = String::new(); if n > 1 { @@ -456,23 +1207,6 @@ pub mod signature { format!("i{}o{}", to_slug(&self.inputs), to_slug(&self.outputs)) } - - pub fn new() -> Signature { - Signature { - inputs: vec![], - outputs: vec![], - } - } - - pub fn inputs(mut self, inputs: Vec) -> Self { - self.inputs = inputs; - self - } - - pub fn outputs(mut self, outputs: Vec) -> Self { - self.outputs = outputs; - self - } } #[cfg(test)] @@ -481,29 +1215,61 @@ pub mod signature { #[test] fn signature() { - let s = Signature::new() - .inputs(vec![Type::FieldElement, Type::Boolean]) - .outputs(vec![Type::Boolean]); + let s = ConcreteSignature::new() + .inputs(vec![ConcreteType::FieldElement, ConcreteType::Boolean]) + .outputs(vec![ConcreteType::Boolean]); assert_eq!(s.to_string(), String::from("(field, bool) -> bool")); } + #[test] + fn signature_equivalence() { + let generic = DeclarationSignature::new() + .generics(vec![Some("P".into())]) + .inputs(vec![DeclarationType::array(DeclarationArrayType::new( + DeclarationType::FieldElement, + "P".into(), + ))]); + let specialized = DeclarationSignature::new().inputs(vec![DeclarationType::array( + DeclarationArrayType::new(DeclarationType::FieldElement, 3u32.into()), + )]); + + assert_eq!(generic, specialized); + assert_eq!( + { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + generic.hash(&mut hasher); + hasher.finish() + }, + { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + specialized.hash(&mut hasher); + hasher.finish() + } + ); + assert_eq!( + generic.partial_cmp(&specialized), + Some(std::cmp::Ordering::Equal) + ); + assert_eq!(generic.cmp(&specialized), std::cmp::Ordering::Equal); + } + #[test] fn slug_0() { - let s = Signature::new().inputs(vec![]).outputs(vec![]); + let s = ConcreteSignature::new().inputs(vec![]).outputs(vec![]); assert_eq!(s.to_slug(), String::from("io")); } #[test] fn slug_1() { - let s = Signature::new() - .inputs(vec![Type::FieldElement, Type::Boolean]) + let s = ConcreteSignature::new() + .inputs(vec![ConcreteType::FieldElement, ConcreteType::Boolean]) .outputs(vec![ - Type::FieldElement, - Type::FieldElement, - Type::Boolean, - Type::FieldElement, + ConcreteType::FieldElement, + ConcreteType::FieldElement, + ConcreteType::Boolean, + ConcreteType::FieldElement, ]); assert_eq!(s.to_slug(), String::from("ifbo2fbf")); @@ -511,23 +1277,27 @@ pub mod signature { #[test] fn slug_2() { - let s = Signature::new() + let s = ConcreteSignature::new() .inputs(vec![ - Type::FieldElement, - Type::FieldElement, - Type::FieldElement, + ConcreteType::FieldElement, + ConcreteType::FieldElement, + ConcreteType::FieldElement, ]) - .outputs(vec![Type::FieldElement, Type::Boolean, Type::FieldElement]); + .outputs(vec![ + ConcreteType::FieldElement, + ConcreteType::Boolean, + ConcreteType::FieldElement, + ]); assert_eq!(s.to_slug(), String::from("i3fofbf")); } #[test] fn array_slug() { - let s = Signature::new() + let s = ConcreteSignature::new() .inputs(vec![ - Type::array(Type::FieldElement, 42), - Type::array(Type::FieldElement, 21), + ConcreteType::array((ConcreteType::FieldElement, 42usize)), + ConcreteType::array((ConcreteType::FieldElement, 21usize)), ]) .outputs(vec![]); @@ -542,7 +1312,17 @@ mod tests { #[test] fn array() { - let t = Type::Array(ArrayType::new(Type::FieldElement, 42)); + let t = ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 42usize)); assert_eq!(t.get_primitive_count(), 42); } + + #[test] + fn array_display() { + // field[1][2] + let t = ConcreteType::Array(ConcreteArrayType::new( + ConcreteType::Array(ConcreteArrayType::new(ConcreteType::FieldElement, 2usize)), + 1usize, + )); + assert_eq!(format!("{}", t), "field[1][2]"); + } } diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 18d7df79..880d8312 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -1,40 +1,75 @@ -use crate::typed_absy::types::{FunctionKey, UBitwidth}; +use crate::typed_absy::types::UBitwidth; use crate::typed_absy::*; +use std::ops::{Add, Div, Mul, Not, Rem, Sub}; use zokrates_field::Field; type Bitwidth = usize; -impl<'ast, T: Field> UExpression<'ast, T> { - pub fn add(self, other: Self) -> UExpression<'ast, T> { +impl<'ast, T> Add for UExpression<'ast, T> { + type Output = Self; + + fn add(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); - UExpressionInner::Add(box self, box other).annotate(bitwidth) - } - pub fn sub(self, other: Self) -> UExpression<'ast, T> { + match (self.as_inner(), other.as_inner()) { + (UExpressionInner::Value(0), _) => other, + (_, UExpressionInner::Value(0)) => self, + _ => UExpressionInner::Add(box self, box other).annotate(bitwidth), + } + } +} + +impl<'ast, T> Sub for UExpression<'ast, T> { + type Output = Self; + + fn sub(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); UExpressionInner::Sub(box self, box other).annotate(bitwidth) } +} - pub fn mult(self, other: Self) -> UExpression<'ast, T> { +impl<'ast, T> Mul for UExpression<'ast, T> { + type Output = Self; + + fn mul(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); UExpressionInner::Mult(box self, box other).annotate(bitwidth) } +} - pub fn div(self, other: Self) -> UExpression<'ast, T> { +impl<'ast, T> Div for UExpression<'ast, T> { + type Output = Self; + + fn div(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); UExpressionInner::Div(box self, box other).annotate(bitwidth) } +} - pub fn rem(self, other: Self) -> UExpression<'ast, T> { +impl<'ast, T> Rem for UExpression<'ast, T> { + type Output = Self; + + fn rem(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); UExpressionInner::Rem(box self, box other).annotate(bitwidth) } +} +impl<'ast, T> Not for UExpression<'ast, T> { + type Output = Self; + + fn not(self) -> UExpression<'ast, T> { + let bitwidth = self.bitwidth; + UExpressionInner::Not(box self).annotate(bitwidth) + } +} + +impl<'ast, T: Field> UExpression<'ast, T> { pub fn xor(self, other: Self) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; assert_eq!(bitwidth, other.bitwidth); @@ -79,6 +114,12 @@ impl<'ast, T: Field> UExpression<'ast, T> { assert_eq!(by.bitwidth, UBitwidth::B32); UExpressionInner::RightShift(box self, box by).annotate(bitwidth) } + + pub fn floor_sub(self, other: Self) -> UExpression<'ast, T> { + let bitwidth = self.bitwidth; + assert_eq!(bitwidth, other.bitwidth); + UExpressionInner::FloorSub(box self, box other).annotate(bitwidth) + } } impl<'ast, T: Field> From for UExpressionInner<'ast, T> { @@ -106,12 +147,40 @@ pub struct UExpression<'ast, T> { pub inner: UExpressionInner<'ast, T>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +impl<'ast, T> From for UExpression<'ast, T> { + fn from(u: u32) -> Self { + UExpressionInner::Value(u as u128).annotate(UBitwidth::B32) + } +} + +impl<'ast, T> From for UExpression<'ast, T> { + fn from(u: u16) -> Self { + UExpressionInner::Value(u as u128).annotate(UBitwidth::B16) + } +} + +impl<'ast, T> From for UExpression<'ast, T> { + fn from(u: u8) -> Self { + UExpressionInner::Value(u as u128).annotate(UBitwidth::B8) + } +} + +impl<'ast, T> PartialEq for UExpression<'ast, T> { + fn eq(&self, other: &usize) -> bool { + match self.as_inner() { + UExpressionInner::Value(v) => *v == *other as u128, + _ => true, + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum UExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(u128), Add(Box>, Box>), Sub(Box>, Box>), + FloorSub(Box>, Box>), Mult(Box>, Box>), Div(Box>, Box>), Rem(Box>, Box>), @@ -121,19 +190,20 @@ pub enum UExpressionInner<'ast, T> { Not(Box>), Neg(Box>), Pos(Box>), + FunctionCall( + DeclarationFunctionKey<'ast>, + Vec>>, + Vec>, + ), LeftShift(Box>, Box>), RightShift(Box>, Box>), - FunctionCall(FunctionKey<'ast>, Vec>), IfElse( Box>, Box>, Box>, ), Member(Box>, MemberId), - Select( - Box>, - Box>, - ), + Select(Box>, Box>), } impl<'ast, T> UExpressionInner<'ast, T> { @@ -159,7 +229,7 @@ pub fn bitwidth(a: u128) -> Bitwidth { (128 - a.leading_zeros()) as Bitwidth } -impl<'ast, T: Field> UExpression<'ast, T> { +impl<'ast, T> UExpression<'ast, T> { pub fn bitwidth(&self) -> UBitwidth { self.bitwidth } diff --git a/zokrates_core/src/typed_absy/variable.rs b/zokrates_core/src/typed_absy/variable.rs index ca40cca9..9151dfdc 100644 --- a/zokrates_core/src/typed_absy/variable.rs +++ b/zokrates_core/src/typed_absy/variable.rs @@ -1,76 +1,102 @@ -use crate::typed_absy::types::Type; -use crate::typed_absy::types::{StructType, UBitwidth}; +use crate::typed_absy::types::{Constant, GStructType, UBitwidth}; +use crate::typed_absy::types::{GType, SpecializationError}; use crate::typed_absy::Identifier; +use crate::typed_absy::UExpression; +use crate::typed_absy::{TryFrom, TryInto}; use std::fmt; #[derive(Clone, PartialEq, Hash, Eq)] -pub struct Variable<'ast> { +pub struct GVariable<'ast, S> { pub id: Identifier<'ast>, - pub _type: Type, + pub _type: GType, } -impl<'ast> Variable<'ast> { - pub fn field_element>>(id: I) -> Variable<'ast> { - Self::with_id_and_type(id, Type::FieldElement) +pub type DeclarationVariable<'ast> = GVariable<'ast, Constant<'ast>>; +pub type ConcreteVariable<'ast> = GVariable<'ast, usize>; +pub type Variable<'ast, T> = GVariable<'ast, UExpression<'ast, T>>; + +impl<'ast, T> TryFrom> for ConcreteVariable<'ast> { + type Error = SpecializationError; + + fn try_from(v: Variable<'ast, T>) -> Result { + let _type = v._type.try_into()?; + + Ok(Self { _type, id: v.id }) + } +} + +// impl<'ast> TryFrom> for ConcreteVariable<'ast> { +// type Error = SpecializationError; + +// fn try_from(v: DeclarationVariable<'ast>) -> Result { +// let _type = v._type.try_into()?; + +// Ok(Self { _type, id: v.id }) +// } +// } + +impl<'ast, T> From> for Variable<'ast, T> { + fn from(v: ConcreteVariable<'ast>) -> Self { + let _type = v._type.into(); + + Self { _type, id: v.id } + } +} + +impl<'ast, T> From> for Variable<'ast, T> { + fn from(v: DeclarationVariable<'ast>) -> Self { + let _type = v._type.into(); + + Self { _type, id: v.id } + } +} + +impl<'ast, S: Clone> GVariable<'ast, S> { + pub fn field_element>>(id: I) -> Self { + Self::with_id_and_type(id, GType::FieldElement) } - pub fn boolean>>(id: I) -> Variable<'ast> { - Self::with_id_and_type(id, Type::Boolean) + pub fn boolean>>(id: I) -> Self { + Self::with_id_and_type(id, GType::Boolean) } - pub fn uint>, W: Into>( - id: I, - bitwidth: W, - ) -> Variable<'ast> { - Self::with_id_and_type(id, Type::uint(bitwidth)) + pub fn uint>, W: Into>(id: I, bitwidth: W) -> Self { + Self::with_id_and_type(id, GType::uint(bitwidth)) } #[cfg(test)] - pub fn field_array>>(id: I, size: usize) -> Variable<'ast> { - Self::array(id, Type::FieldElement, size) + pub fn field_array>>(id: I, size: S) -> Self { + Self::array(id, GType::FieldElement, size) } - pub fn array>>(id: I, ty: Type, size: usize) -> Variable<'ast> { - Self::with_id_and_type(id, Type::array(ty, size)) + pub fn array>, U: Into>(id: I, ty: GType, size: U) -> Self { + Self::with_id_and_type(id, GType::array((ty, size.into()))) } - pub fn struc>>(id: I, ty: StructType) -> Variable<'ast> { - Self::with_id_and_type(id, Type::Struct(ty)) + pub fn struc>>(id: I, ty: GStructType) -> Self { + Self::with_id_and_type(id, GType::Struct(ty)) } - pub fn with_id_and_type>>(id: I, _type: Type) -> Variable<'ast> { - Variable { + pub fn with_id_and_type>>(id: I, _type: GType) -> Self { + GVariable { id: id.into(), _type, } } - pub fn get_type(&self) -> Type { + pub fn get_type(&self) -> GType { self._type.clone() } } -impl<'ast> fmt::Display for Variable<'ast> { +impl<'ast, S: fmt::Display> fmt::Display for GVariable<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} {}", self._type, self.id,) } } -impl<'ast> fmt::Debug for Variable<'ast> { +impl<'ast, S: fmt::Debug> fmt::Debug for GVariable<'ast, S> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "Variable(type: {:?}, id: {:?})", self._type, self.id,) } } - -// impl<'ast> From> for Variable<'ast> { -// fn from(v: absy::Variable) -> Variable { -// Variable::with_id_and_type( -// Identifier { -// id: v.id, -// version: 0, -// stack: vec![], -// }, -// v._type, -// ) -// } -// } diff --git a/zokrates_core/src/zir/folder.rs b/zokrates_core/src/zir/folder.rs index 884a668d..f82fbee1 100644 --- a/zokrates_core/src/zir/folder.rs +++ b/zokrates_core/src/zir/folder.rs @@ -9,17 +9,6 @@ pub trait Folder<'ast, T: Field>: Sized { fold_program(self, p) } - fn fold_module(&mut self, p: ZirModule<'ast, T>) -> ZirModule<'ast, T> { - fold_module(self, p) - } - - fn fold_function_symbol( - &mut self, - s: ZirFunctionSymbol<'ast, T>, - ) -> ZirFunctionSymbol<'ast, T> { - fold_function_symbol(self, s) - } - fn fold_function(&mut self, f: ZirFunction<'ast, T>) -> ZirFunction<'ast, T> { fold_function(self, f) } @@ -63,14 +52,14 @@ pub trait Folder<'ast, T: Field>: Sized { es: ZirExpressionList<'ast, T>, ) -> ZirExpressionList<'ast, T> { match es { - ZirExpressionList::FunctionCall(id, arguments, types) => { - ZirExpressionList::FunctionCall( - id, + ZirExpressionList::EmbedCall(embed, generics, arguments) => { + ZirExpressionList::EmbedCall( + embed, + generics, arguments .into_iter() .map(|a| self.fold_expression(a)) .collect(), - types, ) } } @@ -101,20 +90,6 @@ pub trait Folder<'ast, T: Field>: Sized { } } -pub fn fold_module<'ast, T: Field, F: Folder<'ast, T>>( - f: &mut F, - p: ZirModule<'ast, T>, -) -> ZirModule<'ast, T> { - ZirModule { - functions: p - .functions - .into_iter() - .map(|(key, fun)| (key, f.fold_function_symbol(fun))) - .collect(), - ..p - } -} - pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, s: ZirStatement<'ast, T>, @@ -170,7 +145,7 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } FieldElementExpression::Pow(box e1, box e2) => { let e1 = f.fold_field_expression(e1); - let e2 = f.fold_field_expression(e2); + let e2 = f.fold_uint_expression(e2); FieldElementExpression::Pow(box e1, box e2) } FieldElementExpression::IfElse(box cond, box cons, box alt) => { @@ -204,25 +179,45 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( let e2 = f.fold_uint_expression(e2); BooleanExpression::UintEq(box e1, box e2) } - BooleanExpression::Lt(box e1, box e2) => { + BooleanExpression::FieldLt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Lt(box e1, box e2) + BooleanExpression::FieldLt(box e1, box e2) } - BooleanExpression::Le(box e1, box e2) => { + BooleanExpression::FieldLe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Le(box e1, box e2) + BooleanExpression::FieldLe(box e1, box e2) } - BooleanExpression::Gt(box e1, box e2) => { + BooleanExpression::FieldGt(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Gt(box e1, box e2) + BooleanExpression::FieldGt(box e1, box e2) } - BooleanExpression::Ge(box e1, box e2) => { + BooleanExpression::FieldGe(box e1, box e2) => { let e1 = f.fold_field_expression(e1); let e2 = f.fold_field_expression(e2); - BooleanExpression::Ge(box e1, box e2) + BooleanExpression::FieldGe(box e1, box e2) + } + BooleanExpression::UintLt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintLt(box e1, box e2) + } + BooleanExpression::UintLe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintLe(box e1, box e2) + } + BooleanExpression::UintGt(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintGt(box e1, box e2) + } + BooleanExpression::UintGe(box e1, box e2) => { + let e1 = f.fold_uint_expression(e1); + let e2 = f.fold_uint_expression(e2); + BooleanExpression::UintGe(box e1, box e2) } BooleanExpression::Or(box e1, box e2) => { let e1 = f.fold_boolean_expression(e1); @@ -358,26 +353,11 @@ pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( } } -pub fn fold_function_symbol<'ast, T: Field, F: Folder<'ast, T>>( - f: &mut F, - s: ZirFunctionSymbol<'ast, T>, -) -> ZirFunctionSymbol<'ast, T> { - match s { - ZirFunctionSymbol::Here(fun) => ZirFunctionSymbol::Here(f.fold_function(fun)), - there => there, // by default, do not fold modules recursively - } -} - pub fn fold_program<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, p: ZirProgram<'ast, T>, ) -> ZirProgram<'ast, T> { ZirProgram { - modules: p - .modules - .into_iter() - .map(|(module_id, module)| (module_id, f.fold_module(module))) - .collect(), - main: p.main, + main: f.fold_function(p.main), } } diff --git a/zokrates_core/src/zir/from_typed.rs b/zokrates_core/src/zir/from_typed.rs index 30309e0b..3a756b2e 100644 --- a/zokrates_core/src/zir/from_typed.rs +++ b/zokrates_core/src/zir/from_typed.rs @@ -1,33 +1,28 @@ use crate::typed_absy; use crate::zir; -impl<'ast> From> for zir::types::FunctionKey<'ast> { - fn from(k: typed_absy::types::FunctionKey<'ast>) -> zir::types::FunctionKey<'ast> { - zir::types::FunctionKey { - id: k.id, - signature: k.signature.into(), - } - } -} -impl From for zir::types::Signature { - fn from(s: typed_absy::types::Signature) -> zir::types::Signature { +impl From for zir::types::Signature { + fn from(s: typed_absy::types::ConcreteSignature) -> zir::types::Signature { zir::types::Signature { - inputs: s.inputs.into_iter().flat_map(|t| from_type(t)).collect(), - outputs: s.outputs.into_iter().flat_map(|t| from_type(t)).collect(), + inputs: s.inputs.into_iter().flat_map(from_type).collect(), + outputs: s.outputs.into_iter().flat_map(from_type).collect(), } } } -fn from_type(t: typed_absy::types::Type) -> Vec { +fn from_type(t: typed_absy::types::ConcreteType) -> Vec { match t { - typed_absy::Type::FieldElement => vec![zir::Type::FieldElement], - typed_absy::Type::Boolean => vec![zir::Type::Boolean], - typed_absy::Type::Uint(bitwidth) => vec![zir::Type::uint(bitwidth.to_usize())], - typed_absy::Type::Array(array_type) => { + typed_absy::types::ConcreteType::Int => unreachable!(), + typed_absy::types::ConcreteType::FieldElement => vec![zir::Type::FieldElement], + typed_absy::types::ConcreteType::Boolean => vec![zir::Type::Boolean], + typed_absy::types::ConcreteType::Uint(bitwidth) => { + vec![zir::Type::uint(bitwidth.to_usize())] + } + typed_absy::types::ConcreteType::Array(array_type) => { let inner = from_type(*array_type.ty); (0..array_type.size).flat_map(|_| inner.clone()).collect() } - typed_absy::Type::Struct(members) => members + typed_absy::types::ConcreteType::Struct(members) => members .into_iter() .flat_map(|struct_member| from_type(*struct_member.ty)) .collect(), diff --git a/zokrates_core/src/zir/mod.rs b/zokrates_core/src/zir/mod.rs index 6464951e..6fb06c47 100644 --- a/zokrates_core/src/zir/mod.rs +++ b/zokrates_core/src/zir/mod.rs @@ -10,11 +10,9 @@ pub use self::parameter::Parameter; pub use self::types::Type; pub use self::variable::Variable; pub use crate::zir::uint::{ShouldReduce, UExpression, UExpressionInner, UMetadata}; -use std::path::PathBuf; use crate::embed::FlatEmbed; -use crate::zir::types::{FunctionKey, Signature}; -use std::collections::HashMap; +use crate::zir::types::Signature; use std::convert::TryFrom; use std::fmt; use zokrates_field::Field; @@ -23,115 +21,17 @@ pub use self::folder::Folder; pub use self::identifier::{Identifier, SourceIdentifier}; -/// An identifier for a `ZirModule`. Typically a path or uri. -pub type ZirModuleId = PathBuf; - -/// A collection of `ZirModule`s -pub type ZirModules<'ast, T> = HashMap>; - -/// A collection of `ZirFunctionSymbol`s -/// # Remarks -/// * It is the role of the semantic checker to make sure there are no duplicates for a given `FunctionKey` -/// in a given `ZirModule`, hence the use of a HashMap -pub type ZirFunctionSymbols<'ast, T> = HashMap, ZirFunctionSymbol<'ast, T>>; - /// A typed program as a collection of modules, one of them being the main #[derive(PartialEq, Debug)] pub struct ZirProgram<'ast, T> { - pub modules: ZirModules<'ast, T>, - pub main: ZirModuleId, + pub main: ZirFunction<'ast, T>, } impl<'ast, T: fmt::Display> fmt::Display for ZirProgram<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - for (module_id, module) in &self.modules { - writeln!( - f, - "| {}: |{}", - module_id.display(), - if *module_id == self.main { - "<---- main" - } else { - "" - } - )?; - writeln!(f, "{}", "-".repeat(100))?; - writeln!(f, "{}", module)?; - writeln!(f, "{}", "-".repeat(100))?; - writeln!(f, "")?; - } - write!(f, "") + write!(f, "{}", self.main) } } - -/// A typed program as a collection of functions. Types have been resolved during semantic checking. -#[derive(PartialEq, Clone)] -pub struct ZirModule<'ast, T> { - /// Functions of the program - pub functions: ZirFunctionSymbols<'ast, T>, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum ZirFunctionSymbol<'ast, T> { - Here(ZirFunction<'ast, T>), - There(FunctionKey<'ast>, ZirModuleId), - Flat(FlatEmbed), -} - -impl<'ast, T> ZirFunctionSymbol<'ast, T> { - pub fn signature<'a>(&'a self, modules: &'a ZirModules) -> Signature { - match self { - ZirFunctionSymbol::Here(f) => f.signature.clone(), - ZirFunctionSymbol::There(key, module_id) => modules - .get(module_id) - .unwrap() - .functions - .get(key) - .unwrap() - .signature(&modules) - .clone(), - ZirFunctionSymbol::Flat(flat_fun) => flat_fun.signature().into(), - } - } -} - -impl<'ast, T: fmt::Display> fmt::Display for ZirModule<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let res = self - .functions - .iter() - .map(|(key, symbol)| match symbol { - ZirFunctionSymbol::Here(ref function) => format!("def {}{}", key.id, function), - ZirFunctionSymbol::There(ref fun_key, ref module_id) => format!( - "import {} from \"{}\" as {} // with signature {}", - fun_key.id, - module_id.display(), - key.id, - key.signature - ), - ZirFunctionSymbol::Flat(ref flat_fun) => { - format!("def {}{}:\n\t// hidden", key.id, flat_fun.signature()) - } - }) - .collect::>(); - write!(f, "{}", res.join("\n")) - } -} - -impl<'ast, T: fmt::Debug> fmt::Debug for ZirModule<'ast, T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "module(\n\tfunctions:\n\t\t{}\n)", - self.functions - .iter() - .map(|x| format!("{:?}", x)) - .collect::>() - .join("\n\t\t") - ) - } -} - /// A typed function #[derive(Clone, PartialEq)] pub struct ZirFunction<'ast, T> { @@ -333,15 +233,7 @@ pub trait MultiTyped { #[derive(Clone, PartialEq, Hash, Eq)] pub enum ZirExpressionList<'ast, T> { - FunctionCall(FunctionKey<'ast>, Vec>, Vec), -} - -impl<'ast, T: Field> MultiTyped for ZirExpressionList<'ast, T> { - fn get_types(&self) -> &Vec { - match *self { - ZirExpressionList::FunctionCall(_, _, ref types) => types, - } - } + EmbedCall(FlatEmbed, Vec, Vec>), } /// An expression of type `field` @@ -367,7 +259,7 @@ pub enum FieldElementExpression<'ast, T> { ), Pow( Box>, - Box>, + Box>, ), IfElse( Box>, @@ -381,14 +273,26 @@ pub enum FieldElementExpression<'ast, T> { pub enum BooleanExpression<'ast, T> { Identifier(Identifier<'ast>), Value(bool), - Lt( + FieldLt( Box>, Box>, ), - Le( + FieldLe( Box>, Box>, ), + FieldGe( + Box>, + Box>, + ), + FieldGt( + Box>, + Box>, + ), + UintLt(Box>, Box>), + UintLe(Box>, Box>), + UintGe(Box>, Box>), + UintGt(Box>, Box>), FieldEq( Box>, Box>, @@ -398,14 +302,6 @@ pub enum BooleanExpression<'ast, T> { Box>, ), UintEq(Box>, Box>), - Ge( - Box>, - Box>, - ), - Gt( - Box>, - Box>, - ), Or( Box>, Box>, @@ -538,13 +434,17 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { BooleanExpression::Identifier(ref var) => write!(f, "{}", var), - BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), - BooleanExpression::Le(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), + BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), - BooleanExpression::Ge(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), - BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), @@ -590,8 +490,24 @@ impl<'ast, T: fmt::Debug> fmt::Debug for FieldElementExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for ZirExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ZirExpressionList::FunctionCall(ref key, ref p, _) => { - write!(f, "{}(", key.id,)?; + ZirExpressionList::EmbedCall(ref embed, ref generics, ref p) => { + write!( + f, + "{}{}(", + embed.id(), + if generics.is_empty() { + "".into() + } else { + format!( + "::<{}>", + generics + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + ) + } + )?; for (i, param) in p.iter().enumerate() { write!(f, "{}", param)?; if i < p.len() - 1 { @@ -607,8 +523,8 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirExpressionList<'ast, T> { impl<'ast, T: fmt::Debug> fmt::Debug for ZirExpressionList<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ZirExpressionList::FunctionCall(ref i, ref p, _) => { - write!(f, "FunctionCall({:?}, (", i)?; + ZirExpressionList::EmbedCall(ref embed, ref generics, ref p) => { + write!(f, "EmbedCall({:?}, {:?}, (", generics, embed)?; f.debug_list().entries(p.iter()).finish()?; write!(f, ")") } diff --git a/zokrates_core/src/zir/types.rs b/zokrates_core/src/zir/types.rs index ac5aa6bd..48626022 100644 --- a/zokrates_core/src/zir/types.rs +++ b/zokrates_core/src/zir/types.rs @@ -1,8 +1,6 @@ use serde::{Deserialize, Serialize}; use std::fmt; -pub type Identifier<'ast> = &'ast str; - pub type MemberId = String; #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] @@ -92,44 +90,13 @@ impl Type { } } -pub type FunctionIdentifier<'ast> = &'ast str; - -#[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub struct FunctionKey<'ast> { - pub id: FunctionIdentifier<'ast>, - pub signature: Signature, -} - -impl<'ast> FunctionKey<'ast> { - pub fn with_id>>(id: S) -> Self { - FunctionKey { - id: id.into(), - signature: Signature::new(), - } - } - - pub fn signature(mut self, signature: Signature) -> Self { - self.signature = signature; - self - } - - pub fn id>>(mut self, id: S) -> Self { - self.id = id.into(); - self - } - - pub fn to_slug(&self) -> String { - format!("{}_{}", self.id, self.signature.to_slug()) - } -} - pub use self::signature::Signature; pub mod signature { use super::*; use std::fmt; - #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd)] + #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Ord, PartialOrd, Default)] pub struct Signature { pub inputs: Vec, pub outputs: Vec, @@ -189,12 +156,10 @@ pub mod signature { let len = res.len(); if len == 0 { res.push((1, t)) + } else if res[len - 1].1 == t { + res[len - 1].0 += 1; } else { - if res[len - 1].1 == t { - res[len - 1].0 += 1; - } else { - res.push((1, t)) - } + res.push((1, t)) } } res.into_iter() @@ -217,10 +182,7 @@ pub mod signature { } pub fn new() -> Signature { - Signature { - inputs: vec![], - outputs: vec![], - } + Signature::default() } pub fn inputs(mut self, inputs: Vec) -> Self { diff --git a/zokrates_core/src/zir/uint.rs b/zokrates_core/src/zir/uint.rs index 4f1e4a0d..660e4a54 100644 --- a/zokrates_core/src/zir/uint.rs +++ b/zokrates_core/src/zir/uint.rs @@ -82,6 +82,12 @@ impl<'ast, T: Field> From<&'ast str> for UExpressionInner<'ast, T> { } } +impl<'ast, T> From for UExpression<'ast, T> { + fn from(u: u32) -> Self { + UExpressionInner::Value(u as u128).annotate(UBitwidth::B32) + } +} + #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub enum ShouldReduce { Unknown, @@ -101,10 +107,7 @@ impl ShouldReduce { } pub fn is_unknown(&self) -> bool { - match self { - ShouldReduce::Unknown => true, - _ => false, - } + *self == ShouldReduce::Unknown } // we can always enable a reduction diff --git a/zokrates_core/tests/out_of_range.rs b/zokrates_core/tests/out_of_range.rs index a2142622..ca7f43ab 100644 --- a/zokrates_core/tests/out_of_range.rs +++ b/zokrates_core/tests/out_of_range.rs @@ -36,6 +36,6 @@ fn out_of_range() { let interpreter = Interpreter::try_out_of_range(); assert!(interpreter - .execute(&res.prog(), &vec![Bn128Field::from(10000)]) + .execute(&res.prog(), &[Bn128Field::from(10000)]) .is_err()); } diff --git a/zokrates_core/tests/wasm.rs b/zokrates_core/tests/wasm.rs index 79e6e7b9..6554beb0 100644 --- a/zokrates_core/tests/wasm.rs +++ b/zokrates_core/tests/wasm.rs @@ -29,7 +29,7 @@ fn generate_proof() { let interpreter = Interpreter::default(); let witness = interpreter - .execute(&program, &vec![Bn128Field::from(42)]) + .execute(&program, &[Bn128Field::from(42)]) .unwrap(); let keypair = >::setup(program.clone()); diff --git a/zokrates_core_test/tests/tests/fact_up_to_4.zok b/zokrates_core_test/tests/tests/fact_up_to_4.zok index d33b19c0..299cd995 100644 --- a/zokrates_core_test/tests/tests/fact_up_to_4.zok +++ b/zokrates_core_test/tests/tests/fact_up_to_4.zok @@ -1,8 +1,10 @@ +import "utils/casts/u32_to_field" as to_field + def main(field x) -> field: field f = 1 field counter = 0 - for field i in 1..5 do - f = if counter == x then f else f * i fi + for u32 i in 1..5 do + f = if counter == x then f else f * to_field(i) fi counter = if counter == x then counter else counter + 1 fi endfor return f \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/cache.json b/zokrates_core_test/tests/tests/generics/cache.json new file mode 100644 index 00000000..749eefef --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/cache.json @@ -0,0 +1,17 @@ +{ + "curves": ["Bn128", "Bls12_381"], + "tests": [ + { + "input": { + "values": [ + ] + }, + "output": { + "Ok": { + "values": [ + ] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/cache.zok b/zokrates_core_test/tests/tests/generics/cache.zok new file mode 100644 index 00000000..2a2f1751 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/cache.zok @@ -0,0 +1,7 @@ +def id() -> u32: + return N + +def main(): + assert(id::<5>() == 5) + assert(id::<6>() == 6) + return \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/call.json b/zokrates_core_test/tests/tests/generics/call.json new file mode 100644 index 00000000..4e45eccc --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/call.json @@ -0,0 +1,4 @@ +{ + "curves": ["Bn128", "Bls12_381"], + "tests": [] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/call.zok b/zokrates_core_test/tests/tests/generics/call.zok new file mode 100644 index 00000000..a0e6aa05 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/call.zok @@ -0,0 +1,8 @@ +def foo(field[T] b) -> field: + return 1 + +def bar(field[T] b) -> field: + return foo(b) + +def main(field[3] a) -> field: + return foo(a) + bar(a) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/embed.json b/zokrates_core_test/tests/tests/generics/embed.json new file mode 100644 index 00000000..4e45eccc --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/embed.json @@ -0,0 +1,4 @@ +{ + "curves": ["Bn128", "Bls12_381"], + "tests": [] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/embed.zok b/zokrates_core_test/tests/tests/generics/embed.zok new file mode 100644 index 00000000..88160bba --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/embed.zok @@ -0,0 +1,5 @@ +import "EMBED/unpack" as unpack + +def main(field x): + bool[1] bits = unpack(x) + return \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/multidef.json b/zokrates_core_test/tests/tests/generics/multidef.json new file mode 100644 index 00000000..4f0d5931 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/multidef.json @@ -0,0 +1,23 @@ +{ + "curves": ["Bn128", "Bls12_381"], + "tests": [ + { + "input": { + "values": [ + "1", + "2", + "3" + ] + }, + "output": { + "Ok": { + "values": [ + "1", + "2", + "3" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/multidef.zok b/zokrates_core_test/tests/tests/generics/multidef.zok new file mode 100644 index 00000000..dbf3bd3f --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/multidef.zok @@ -0,0 +1,6 @@ +def foo(field[T] b) -> field[T]: + return b + +def main(field[3] a) -> field[3]: + field[3] res = foo(a) + return res \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/same_parameter_name.json b/zokrates_core_test/tests/tests/generics/same_parameter_name.json new file mode 100644 index 00000000..b97905d7 --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/same_parameter_name.json @@ -0,0 +1,22 @@ +{ + "curves": ["Bn128", "Bls12_381"], + "tests": [ + { + "input": { + "values": [ + "1", + "2", + "3" + ] + }, + "output": { + "Ok": { + "values": [ + "1", + "2" + ] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/generics/same_parameter_name.zok b/zokrates_core_test/tests/tests/generics/same_parameter_name.zok new file mode 100644 index 00000000..182637fb --- /dev/null +++ b/zokrates_core_test/tests/tests/generics/same_parameter_name.zok @@ -0,0 +1,10 @@ +def foo(field[N] x) -> field[N]: + return x + +def bar(field[N] x) -> field[N]: + field[N] r = x + return r + +def main(field[3] x) -> field[2]: + field[2] z = foo(x)[0..2] + return bar(z) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/neg_pos.zok b/zokrates_core_test/tests/tests/neg_pos.zok index e101328b..2322df36 100644 --- a/zokrates_core_test/tests/tests/neg_pos.zok +++ b/zokrates_core_test/tests/tests/neg_pos.zok @@ -10,6 +10,6 @@ def main(field x, field y, u8 z, u8 t) -> (field[4], u8[4]): u8 h = z - + t // should parse to sub(pos) assert(-0x00 == 0x00) - assert(-0 == 0) + assert(-0f == 0) return [a, b, c, d], [e, f, g, h] \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/nested_loop.zok b/zokrates_core_test/tests/tests/nested_loop.zok index 2d00f1a0..00ee862f 100644 --- a/zokrates_core_test/tests/tests/nested_loop.zok +++ b/zokrates_core_test/tests/tests/nested_loop.zok @@ -2,24 +2,24 @@ def main(field[4] values) -> (field, field, field): field res0 = 1 field res1 = 0 - field counter = 0 + u32 counter = 0 - for field i in 0..4 do - for field j in i..4 do + for u32 i in 0..4 do + for u32 j in i..4 do counter = counter + 1 res0 = res0 * (values[i] + values[j]) endfor endfor - for field i in 0..counter do + for u32 i in 0..counter do res1 = res1 + 1 endfor field res2 = 0 - field i = 0 - for field i in i..5 do + u32 i = 0 + for u32 i in i..5 do i = 5 - for field i in 0..i do + for u32 i in 0..i do res2 = res2 + 1 endfor endfor diff --git a/zokrates_core_test/tests/tests/precedence.zok b/zokrates_core_test/tests/tests/precedence.zok index fcc1e9e8..fb2edcef 100644 --- a/zokrates_core_test/tests/tests/precedence.zok +++ b/zokrates_core_test/tests/tests/precedence.zok @@ -4,12 +4,12 @@ def main(): assert(7 == 2 ** 2 * 2 - 1) assert(3 == 2 ** 2 / 2 + 1) - field a = if 3 == 2 ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators - field b = if 3 == 3 && 4 < 5 then 1 else 0 fi // checks precedence of boolean operators - field c = if 4 < 5 && 3 == 3 then 1 else 0 fi - field d = if 4 > 5 && 2 >= 1 || 1 == 1 then 1 else 0 fi - field e = if 2 >= 1 && 4 > 5 || 1 == 1 then 1 else 0 fi - field f = if 1 < 2 && false || 4 < 5 && 2 >= 1 then 1 else 0 fi + field a = if 3f == 2f ** 2 / 2 + 1 && true then 1 else 0 fi // combines arithmetic with boolean operators + field b = if 3f == 3f && 4f < 5f then 1 else 0 fi // checks precedence of boolean operators + field c = if 4f < 5f && 3f == 3f then 1 else 0 fi + field d = if 4f > 5f && 2f >= 1f || 1f == 1f then 1 else 0 fi + field e = if 2f >= 1f && 4f > 5f || 1f == 1f then 1 else 0 fi + field f = if 1f < 2f && false || 4f < 5f && 2f >= 1f then 1 else 0 fi assert(0x00 ^ 0x00 == 0x00) diff --git a/zokrates_core_test/tests/tests/uint/add_loop.zok b/zokrates_core_test/tests/tests/uint/add_loop.zok index 53a0c19b..cb416240 100644 --- a/zokrates_core_test/tests/tests/uint/add_loop.zok +++ b/zokrates_core_test/tests/tests/uint/add_loop.zok @@ -1,7 +1,7 @@ // 32 constraints for input constraining def main(u32 a) -> u32: u32 res = 0x00000000 - for field i in 0..10 do + for u32 i in 0..10 do res = res + a endfor return res // 42 constraints (decomposing on 42 bits) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/uint/extend.zok b/zokrates_core_test/tests/tests/uint/extend.zok index 82cd4d77..f50f328a 100644 --- a/zokrates_core_test/tests/tests/uint/extend.zok +++ b/zokrates_core_test/tests/tests/uint/extend.zok @@ -43,7 +43,7 @@ def right_rotate_25(u32 e) -> u32: // already paying 33 * 64 = 2112 constraints for input constraining. in sha this is done only once. So this is just ~ 152 constraints def main(u32[64] w) -> u32: - field i = 16 - u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 0x00000003) - u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 0x0000000a) + u32 i = 16 + u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3) + u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10) return w[i-16] + s0 + w[i-7] + s1 \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/uint/sha256.zok b/zokrates_core_test/tests/tests/uint/sha256.zok index 0f3e0e58..7aabe22c 100644 --- a/zokrates_core_test/tests/tests/uint/sha256.zok +++ b/zokrates_core_test/tests/tests/uint/sha256.zok @@ -41,9 +41,9 @@ def right_rotate_25(u32 e) -> u32: bool[32] b = to_bits(e) return from_bits([...b[7..], ...b[..7]]) -def extend(u32[64] w, field i) -> u32: - u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 0x00000003) - u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 0x0000000a) +def extend(u32[64] w, u32 i) -> u32: + u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3) + u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10) return w[i-16] + s0 + w[i-7] + s1 def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: @@ -82,11 +82,11 @@ def main(u32[1][16] input) -> u32[8]: // assume padding is already done // input = input - for field i in 0..1 do + for u32 i in 0..1 do u32[64] w = [...input[0], ...[0x00000000; 48]] - for field i in 16..64 do + for u32 i in 16..64 do u32 r = extend(w, i) w[i] = r endfor @@ -100,7 +100,7 @@ def main(u32[1][16] input) -> u32[8]: u32 g = h6 u32 h = h7 - for field i in 0..64 do + for u32 i in 0..64 do u32 t1 = temp1(e, f, g, h, k[i], w[i]) diff --git a/zokrates_field/src/bn128.rs b/zokrates_field/src/bn128.rs index cecca736..fae0d1e7 100644 --- a/zokrates_field/src/bn128.rs +++ b/zokrates_field/src/bn128.rs @@ -215,18 +215,6 @@ mod tests { assert_eq!(FieldPrime::from(-12), FieldPrime::from(-85) * res); } - #[test] - fn pow_small() { - assert_eq!( - "8".parse::().unwrap(), - (FieldPrime::from("2").pow(FieldPrime::from("3"))).value - ); - assert_eq!( - "8".parse::().unwrap(), - (FieldPrime::from("2").pow(&FieldPrime::from("3"))).value - ); - } - #[test] fn pow_usize() { assert_eq!( @@ -235,34 +223,6 @@ mod tests { ); } - #[test] - fn pow() { - assert_eq!( - "614787626176508399616".parse::().unwrap(), - (FieldPrime::from("54").pow(FieldPrime::from("12"))).value - ); - assert_eq!( - "614787626176508399616".parse::().unwrap(), - (FieldPrime::from("54").pow(&FieldPrime::from("12"))).value - ); - } - - #[test] - fn pow_negative() { - assert_eq!( - "21888242871839275222246405745257275088548364400416034343686819230535502784513" - .parse::() - .unwrap(), - (FieldPrime::from("-54").pow(FieldPrime::from("11"))).value - ); - assert_eq!( - "21888242871839275222246405745257275088548364400416034343686819230535502784513" - .parse::() - .unwrap(), - (FieldPrime::from("-54").pow(&FieldPrime::from("11"))).value - ); - } - #[test] fn serde_ser_deser() { let serialized = &serialize(&FieldPrime::from("11"), Infinite).unwrap(); diff --git a/zokrates_field/src/lib.rs b/zokrates_field/src/lib.rs index ae0bdaf6..f125580e 100644 --- a/zokrates_field/src/lib.rs +++ b/zokrates_field/src/lib.rs @@ -16,6 +16,7 @@ use num_bigint::BigUint; use num_traits::{CheckedDiv, One, Zero}; use serde::{Deserialize, Serialize}; use std::convert::{From, TryFrom}; +use std::fmt; use std::fmt::{Debug, Display}; use std::hash::Hash; use std::ops::{Add, Div, Mul, Sub}; @@ -44,6 +45,14 @@ pub trait ArkFieldExtensions { fn into_ark(self) -> ::Fr; } +pub struct FieldParseError; + +impl fmt::Debug for FieldParseError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Failed to parse to field element") + } +} + pub trait Field: From + From @@ -60,6 +69,8 @@ pub trait Field: + Ord + Display + Debug + + Default + + Hash + Add + for<'a> Add<&'a Self, Output = Self> + Sub @@ -68,9 +79,8 @@ pub trait Field: + for<'a> Mul<&'a Self, Output = Self> + CheckedDiv + Div + + for<'a> Div<&'a Self, Output = Self> + Pow - + Pow - + for<'a> Pow<&'a Self, Output = Self> + for<'a> Deserialize<'a> + Serialize + num_traits::CheckedAdd @@ -94,8 +104,8 @@ pub trait Field: /// Returns the number of bits required to represent any element of this field type. fn get_required_bits() -> usize; /// Tries to parse a string into this representation - fn try_from_dec_str<'a>(s: &'a str) -> Result; - fn try_from_str(s: &str, radix: u32) -> Result; + fn try_from_dec_str(s: &str) -> Result; + fn try_from_str(s: &str, radix: u32) -> Result; /// Returns a decimal string representing a the member of the equivalence class of this `Field` in Z/pZ /// which lies in [-(p-1)/2, (p-1)/2] fn to_compact_dec_string(&self) -> String; @@ -130,7 +140,7 @@ pub trait Field: mod prime_field { macro_rules! prime_field { ($modulus:expr, $name:expr) => { - use crate::{Field, Pow}; + use crate::{Field, FieldParseError, Pow}; use lazy_static::lazy_static; use num_bigint::{BigInt, BigUint, Sign, ToBigInt}; use num_integer::Integer; @@ -208,11 +218,11 @@ mod prime_field { fn get_required_bits() -> usize { (*P).bits() } - fn try_from_dec_str<'a>(s: &'a str) -> Result { + fn try_from_dec_str(s: &str) -> Result { Self::try_from_str(s, 10) } - fn try_from_str(s: &str, radix: u32) -> Result { - let x = BigInt::parse_bytes(s.as_bytes(), radix).ok_or(())?; + fn try_from_str(s: &str, radix: u32) -> Result { + let x = BigInt::parse_bytes(s.as_bytes(), radix).ok_or(FieldParseError)?; Ok(FieldPrime { value: &x - x.div_floor(&*P) * &*P, }) @@ -337,6 +347,14 @@ mod prime_field { type Output = FieldPrime; fn add(self, other: FieldPrime) -> FieldPrime { + if self.value == BigInt::zero() { + return other; + } + + if other.value == BigInt::zero() { + return self; + } + FieldPrime { value: (self.value + other.value) % &*P, } @@ -347,8 +365,16 @@ mod prime_field { type Output = FieldPrime; fn add(self, other: &FieldPrime) -> FieldPrime { + if self.value == BigInt::zero() { + return other.clone(); + } + + if other.value == BigInt::zero() { + return self; + } + FieldPrime { - value: (self.value + other.value.clone()) % &*P, + value: (self.value + &other.value) % &*P, } } } @@ -368,7 +394,7 @@ mod prime_field { type Output = FieldPrime; fn sub(self, other: &FieldPrime) -> FieldPrime { - let x = self.value - other.value.clone(); + let x = self.value - &other.value; FieldPrime { value: &x - x.div_floor(&*P) * &*P, } @@ -379,6 +405,14 @@ mod prime_field { type Output = FieldPrime; fn mul(self, other: FieldPrime) -> FieldPrime { + if self.value == BigInt::one() { + return other; + } + + if other.value == BigInt::one() { + return self; + } + FieldPrime { value: (self.value * other.value) % &*P, } @@ -389,8 +423,16 @@ mod prime_field { type Output = FieldPrime; fn mul(self, other: &FieldPrime) -> FieldPrime { + if self.value == BigInt::one() { + return other.clone(); + } + + if other.value == BigInt::one() { + return self; + } + FieldPrime { - value: (self.value * other.value.clone()) % &*P, + value: (self.value * &other.value) % &*P, } } } @@ -429,38 +471,6 @@ mod prime_field { } } - impl Pow for FieldPrime { - type Output = FieldPrime; - - fn pow(self, exp: FieldPrime) -> FieldPrime { - let mut res = FieldPrime::one(); - let mut current = FieldPrime::zero(); - loop { - if current >= exp { - return res; - } - res = res * &self; - current = current + FieldPrime::one(); - } - } - } - - impl<'a> Pow<&'a FieldPrime> for FieldPrime { - type Output = FieldPrime; - - fn pow(self, exp: &'a FieldPrime) -> FieldPrime { - let mut res = FieldPrime::one(); - let mut current = FieldPrime::zero(); - loop { - if ¤t >= exp { - return res; - } - res = res * &self; - current = current + FieldPrime::one(); - } - } - } - impl num_traits::CheckedAdd for FieldPrime { fn checked_add(&self, other: &Self) -> Option { let bound = Self::max_unique_value(); @@ -468,7 +478,7 @@ mod prime_field { assert!(self <= &bound); assert!(other <= &bound); - let big_res = self.value.clone() + other.value.clone(); + let big_res = &self.value + &other.value; if big_res > bound.value { None @@ -485,7 +495,7 @@ mod prime_field { assert!(self <= &bound); assert!(other <= &bound); - let big_res = self.value.clone() * other.value.clone(); + let big_res = &self.value * &other.value; // we only go up to 2**(bitwidth - 1) because after that we lose uniqueness of bit decomposition if big_res > bound.value { diff --git a/zokrates_fs_resolver/src/lib.rs b/zokrates_fs_resolver/src/lib.rs index 9da6900e..b38a0b17 100644 --- a/zokrates_fs_resolver/src/lib.rs +++ b/zokrates_fs_resolver/src/lib.rs @@ -37,14 +37,12 @@ impl<'a> Resolver for FileSystemResolver<'a> { // other paths `abc/def` are interpreted relative to the standard library root path let base = match source.components().next() { Some(Component::CurDir) | Some(Component::ParentDir) => { - PathBuf::from(current_location).parent().unwrap().into() + current_location.parent().unwrap().into() } _ => PathBuf::from(self.stdlib_root_path.unwrap_or("")), }; - let path_owned = base - .join(PathBuf::from(import_location.clone())) - .with_extension("zok"); + let path_owned = base.join(import_location.clone()).with_extension("zok"); if !path_owned.is_file() { return Err(io::Error::new( @@ -96,7 +94,7 @@ mod tests { // create a source folder with a zok file let folder = tempfile::tempdir().unwrap(); let dir_path = folder.path().join("dir"); - std::fs::create_dir(dir_path.clone()).unwrap(); + std::fs::create_dir(dir_path).unwrap(); let fs_resolver = FileSystemResolver::default(); let res = fs_resolver.resolve(".".into(), "./dir/".into()); @@ -155,7 +153,7 @@ mod tests { let stdlib_root_path = temp_dir.path().to_owned(); let fs_resolver = FileSystemResolver::with_stdlib_root(stdlib_root_path.to_str().unwrap()); - let result = fs_resolver.resolve(file_path.clone(), "bar.zok".into()); + let result = fs_resolver.resolve(file_path, "bar.zok".into()); assert!(result.is_ok()); // the imported file should be the user's assert_eq!(result.unwrap().0, String::from("\n")); diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 41af7fdd..2fb5aa88 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -14,7 +14,7 @@ use zokrates_core::proof_system::bellman::Bellman; use zokrates_core::proof_system::groth16::G16; use zokrates_core::proof_system::{Backend, Proof, Scheme, SolidityAbi, SolidityCompatibleScheme}; use zokrates_core::typed_absy::abi::Abi; -use zokrates_core::typed_absy::types::Signature; +use zokrates_core::typed_absy::types::ConcreteSignature as Signature; use zokrates_field::Bn128Field; #[derive(Serialize, Deserialize)] diff --git a/zokrates_parser/Cargo.toml b/zokrates_parser/Cargo.toml index e056c0d8..cc77f702 100644 --- a/zokrates_parser/Cargo.toml +++ b/zokrates_parser/Cargo.toml @@ -9,4 +9,4 @@ pest = "2.0" pest_derive = "2.0" [dev-dependencies] -glob = "0.2" \ No newline at end of file +glob = "0.2" diff --git a/zokrates_parser/src/zokrates.pest b/zokrates_parser/src/zokrates.pest index 1ecb76fb..94daad5e 100644 --- a/zokrates_parser/src/zokrates.pest +++ b/zokrates_parser/src/zokrates.pest @@ -8,8 +8,10 @@ import_directive = { main_import_directive | from_import_directive } from_import_directive = { "from" ~ "\"" ~ import_source ~ "\"" ~ "import" ~ identifier ~ ("as" ~ identifier)? ~ NEWLINE*} main_import_directive = {"import" ~ "\"" ~ import_source ~ "\"" ~ ("as" ~ identifier)? ~ NEWLINE+} import_source = @{(!"\"" ~ ANY)*} -function_definition = {"def" ~ identifier ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* } +function_definition = {"def" ~ identifier ~ constant_generics_declaration? ~ "(" ~ parameter_list ~ ")" ~ return_types ~ ":" ~ NEWLINE* ~ statement* } return_types = _{ ( "->" ~ ( "(" ~ type_list ~ ")" | ty ))? } +constant_generics_declaration = _{ "<" ~ constant_generics_list ~ ">" } +constant_generics_list = _{ identifier ~ ("," ~ identifier)* } parameter_list = _{(parameter ~ ("," ~ parameter)*)?} parameter = {vis? ~ ty ~ identifier} @@ -66,14 +68,19 @@ to_expression = { expression } conditional_expression = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"} -postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers. Not sure that's what we want. +postfix_expression = { identifier ~ access+ } // we force there to be at least one access, otherwise this matches single identifiers access = { array_access | call_access | member_access } array_access = { "[" ~ range_or_expression ~ "]" } -call_access = { "(" ~ expression_list ~ ")" } +call_access = { explicit_generics? ~ "(" ~ arguments ~ ")" } +arguments = { expression_list } +explicit_generics = { "::<" ~ constant_generics_values ~ ">" } +constant_generics_values = _{ constant_generics_value ~ ("," ~ constant_generics_value)* } +constant_generics_value = { literal | identifier | underscore } +underscore = { "_" } member_access = { "." ~ identifier } primary_expression = { identifier - | constant + | literal } inline_struct_expression = { identifier ~ "{" ~ NEWLINE* ~ inline_struct_member_list ~ NEWLINE* ~ "}" } @@ -84,21 +91,37 @@ inline_array_expression = { "[" ~ NEWLINE* ~ inline_array_inner ~ NEWLINE* ~ "]" inline_array_inner = _{(spread_or_expression ~ ("," ~ NEWLINE* ~ spread_or_expression)*)?} spread_or_expression = { spread | expression } range_or_expression = { range | expression } -array_initializer_expression = { "[" ~ expression ~ ";" ~ constant ~ "]" } + exponent_expression = { "(" ~ expression ~ ")" | primary_expression } +array_initializer_expression = { "[" ~ expression ~ ";" ~ expression ~ "]" } // End Expressions assignee = { identifier ~ assignee_access* } assignee_access = { array_access | member_access } identifier = @{ ((!keyword ~ ASCII_ALPHA) | (keyword ~ (ASCII_ALPHANUMERIC | "_"))) ~ (ASCII_ALPHANUMERIC | "_")* } -constant = { hex_number | decimal_number | boolean_literal } + +// Literals for all types + +literal = { hex_literal | decimal_literal | boolean_literal } + +decimal_literal = ${ decimal_number ~ decimal_suffix? } decimal_number = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* } +decimal_suffix = { decimal_suffix_u8 | decimal_suffix_u16 | decimal_suffix_u32 | decimal_suffix_field } +decimal_suffix_u8 = { "u8" } +decimal_suffix_u16 = { "u16" } +decimal_suffix_u32 = { "u32" } +decimal_suffix_field = { "f" } + boolean_literal = { "true" | "false" } -hex_number = _{ hex_number_32 | hex_number_16 | hex_number_8 } -hex_number_8 = @{ "0x" ~ ASCII_HEX_DIGIT{2} } -hex_number_16 = @{ "0x" ~ ASCII_HEX_DIGIT{4} } -hex_number_32 = @{ "0x" ~ ASCII_HEX_DIGIT{8} } + +hex_literal = !{ "0x" ~ hex_number } +hex_number = { hex_number_u32 | hex_number_u16 | hex_number_u8 } +hex_number_u8 = { ASCII_HEX_DIGIT{2} } +hex_number_u16 = { ASCII_HEX_DIGIT{4} } +hex_number_u32 = { ASCII_HEX_DIGIT{8} } + +// Operators op_or = @{"||"} op_and = @{"&&"} diff --git a/zokrates_pest_ast/src/lib.rs b/zokrates_pest_ast/src/lib.rs index 76c9c92b..edc4ba1c 100644 --- a/zokrates_pest_ast/src/lib.rs +++ b/zokrates_pest_ast/src/lib.rs @@ -8,14 +8,16 @@ use zokrates_parser::Rule; extern crate lazy_static; pub use ast::{ - Access, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, Assignee, - AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, CallAccess, - ConstantExpression, DecimalNumberExpression, DefinitionStatement, Expression, FieldType, File, - FromExpression, Function, IdentifierExpression, ImportDirective, ImportSource, + Access, Arguments, ArrayAccess, ArrayInitializerExpression, ArrayType, AssertionStatement, + Assignee, AssigneeAccess, BasicOrStructType, BasicType, BinaryExpression, BinaryOperator, + CallAccess, ConstantGenericValue, DecimalLiteralExpression, DecimalNumber, DecimalSuffix, + DefinitionStatement, ExplicitGenerics, Expression, FieldType, File, FromExpression, Function, + HexLiteralExpression, HexNumberExpression, IdentifierExpression, ImportDirective, ImportSource, InlineArrayExpression, InlineStructExpression, InlineStructMember, IterationStatement, - OptionallyTypedAssignee, Parameter, PostfixExpression, Range, RangeOrExpression, - ReturnStatement, Span, Spread, SpreadOrExpression, Statement, StructDefinition, StructField, - TernaryExpression, ToExpression, Type, UnaryExpression, UnaryOperator, Visibility, + LiteralExpression, OptionallyTypedAssignee, Parameter, PostfixExpression, Range, + RangeOrExpression, ReturnStatement, Span, Spread, SpreadOrExpression, Statement, + StructDefinition, StructField, TernaryExpression, ToExpression, Type, UnaryExpression, + UnaryOperator, Underscore, Visibility, }; mod ast { @@ -154,6 +156,7 @@ mod ast { #[pest_ast(rule(Rule::function_definition))] pub struct Function<'ast> { pub id: IdentifierExpression<'ast>, + pub generics: Vec>, pub parameters: Vec>, pub returns: Vec>, pub statements: Vec>, @@ -298,6 +301,7 @@ mod ast { #[pest_ast(rule(Rule::vis_private))] pub struct PrivateVisibility {} + #[allow(clippy::large_enum_variant)] #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::statement))] pub enum Statement<'ast> { @@ -374,7 +378,7 @@ mod ast { Unary(UnaryExpression<'ast>), Postfix(PostfixExpression<'ast>), Identifier(IdentifierExpression<'ast>), - Constant(ConstantExpression<'ast>), + Literal(LiteralExpression<'ast>), InlineArray(InlineArrayExpression<'ast>), InlineStruct(InlineStructExpression<'ast>), ArrayInitializer(ArrayInitializerExpression<'ast>), @@ -484,13 +488,13 @@ mod ast { #[pest_ast(rule(Rule::primary_expression))] pub enum PrimaryExpression<'ast> { Identifier(IdentifierExpression<'ast>), - Constant(ConstantExpression<'ast>), + Literal(LiteralExpression<'ast>), } impl<'ast> From> for Expression<'ast> { fn from(e: PrimaryExpression<'ast>) -> Self { match e { - PrimaryExpression::Constant(c) => Expression::Constant(c), + PrimaryExpression::Literal(c) => Expression::Literal(c), PrimaryExpression::Identifier(i) => Expression::Identifier(i), } } @@ -590,7 +594,7 @@ mod ast { #[pest_ast(rule(Rule::array_initializer_expression))] pub struct ArrayInitializerExpression<'ast> { pub value: Box>, - pub count: ConstantExpression<'ast>, + pub count: Box>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -604,6 +608,7 @@ mod ast { pub span: Span<'ast>, } + #[allow(clippy::large_enum_variant)] #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::access))] pub enum Access<'ast> { @@ -622,6 +627,38 @@ mod ast { #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::call_access))] pub struct CallAccess<'ast> { + pub explicit_generics: Option>, + pub arguments: Arguments<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::explicit_generics))] + pub struct ExplicitGenerics<'ast> { + pub values: Vec>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::constant_generics_value))] + pub enum ConstantGenericValue<'ast> { + Value(LiteralExpression<'ast>), + Identifier(IdentifierExpression<'ast>), + Underscore(Underscore<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::underscore))] + pub struct Underscore<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::arguments))] + pub struct Arguments<'ast> { pub expressions: Vec>, #[pest_ast(outer())] pub span: Span<'ast>, @@ -701,7 +738,7 @@ mod ast { match self { Expression::Binary(b) => &b.span, Expression::Identifier(i) => &i.span, - Expression::Constant(c) => &c.span(), + Expression::Literal(c) => &c.span(), Expression::Ternary(t) => &t.span, Expression::Postfix(p) => &p.span, Expression::InlineArray(a) => &a.span, @@ -736,32 +773,72 @@ mod ast { } #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::constant))] - pub enum ConstantExpression<'ast> { - DecimalNumber(DecimalNumberExpression<'ast>), + #[pest_ast(rule(Rule::literal))] + pub enum LiteralExpression<'ast> { + DecimalLiteral(DecimalLiteralExpression<'ast>), BooleanLiteral(BooleanLiteralExpression<'ast>), - U8(U8NumberExpression<'ast>), - U16(U16NumberExpression<'ast>), - U32(U32NumberExpression<'ast>), + HexLiteral(HexLiteralExpression<'ast>), } - impl<'ast> ConstantExpression<'ast> { + impl<'ast> LiteralExpression<'ast> { pub fn span(&self) -> &Span<'ast> { match self { - ConstantExpression::DecimalNumber(n) => &n.span, - ConstantExpression::BooleanLiteral(c) => &c.span, - ConstantExpression::U8(c) => &c.span, - ConstantExpression::U16(c) => &c.span, - ConstantExpression::U32(c) => &c.span, + LiteralExpression::DecimalLiteral(n) => &n.span, + LiteralExpression::BooleanLiteral(c) => &c.span, + LiteralExpression::HexLiteral(h) => &h.span, } } } + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_suffix))] + pub enum DecimalSuffix<'ast> { + U8(U8Suffix<'ast>), + U16(U16Suffix<'ast>), + U32(U32Suffix<'ast>), + Field(FieldSuffix<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_suffix_u8))] + pub struct U8Suffix<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_suffix_u16))] + pub struct U16Suffix<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_suffix_u32))] + pub struct U32Suffix<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_suffix_field))] + pub struct FieldSuffix<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + #[derive(Debug, FromPest, PartialEq, Clone)] #[pest_ast(rule(Rule::decimal_number))] - pub struct DecimalNumberExpression<'ast> { - #[pest_ast(outer(with(span_into_str)))] - pub value: String, + pub struct DecimalNumber<'ast> { + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::decimal_literal))] + pub struct DecimalLiteralExpression<'ast> { + pub value: DecimalNumber<'ast>, + pub suffix: Option>, #[pest_ast(outer())] pub span: Span<'ast>, } @@ -776,7 +853,23 @@ mod ast { } #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::hex_number_8))] + #[pest_ast(rule(Rule::hex_literal))] + pub struct HexLiteralExpression<'ast> { + pub value: HexNumberExpression<'ast>, + #[pest_ast(outer())] + pub span: Span<'ast>, + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::hex_number))] + pub enum HexNumberExpression<'ast> { + U8(U8NumberExpression<'ast>), + U16(U16NumberExpression<'ast>), + U32(U32NumberExpression<'ast>), + } + + #[derive(Debug, FromPest, PartialEq, Clone)] + #[pest_ast(rule(Rule::hex_number_u8))] pub struct U8NumberExpression<'ast> { #[pest_ast(outer(with(span_into_str)))] pub value: String, @@ -785,7 +878,7 @@ mod ast { } #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::hex_number_16))] + #[pest_ast(rule(Rule::hex_number_u16))] pub struct U16NumberExpression<'ast> { #[pest_ast(outer(with(span_into_str)))] pub value: String, @@ -794,7 +887,7 @@ mod ast { } #[derive(Debug, FromPest, PartialEq, Clone)] - #[pest_ast(rule(Rule::hex_number_32))] + #[pest_ast(rule(Rule::hex_number_u32))] pub struct U32NumberExpression<'ast> { #[pest_ast(outer(with(span_into_str)))] pub value: String, @@ -847,7 +940,7 @@ impl fmt::Display for Error { } pub fn generate_ast(input: &str) -> Result { - let parse_tree = parse(input).map_err(|e| Error(e))?; + let parse_tree = parse(input).map_err(Error)?; Ok(Prog::from(parse_tree).0) } @@ -920,6 +1013,7 @@ mod tests { pragma: None, structs: vec![], functions: vec![Function { + generics: vec![], id: IdentifierExpression { value: String::from("main"), span: Span::new(&source, 33, 37).unwrap() @@ -930,15 +1024,21 @@ mod tests { }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::add( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + value: DecimalNumber { + span: Span::new(&source, 59, 60).unwrap() + }, + suffix: None, span: Span::new(&source, 59, 60).unwrap() } )), - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + value: DecimalNumber { + span: Span::new(&source, 63, 64).unwrap() + }, + suffix: None, span: Span::new(&source, 63, 64).unwrap() } )), @@ -973,6 +1073,7 @@ mod tests { pragma: None, structs: vec![], functions: vec![Function { + generics: vec![], id: IdentifierExpression { value: String::from("main"), span: Span::new(&source, 33, 37).unwrap() @@ -983,29 +1084,41 @@ mod tests { }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::add( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 59, 60).unwrap() + }, span: Span::new(&source, 59, 60).unwrap() } )), Expression::mul( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("2"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 63, 64).unwrap() + }, span: Span::new(&source, 63, 64).unwrap() } )), Expression::pow( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("3"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 67, 68).unwrap() + }, span: Span::new(&source, 67, 68).unwrap() } )), - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("4"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 72, 73).unwrap() + }, span: Span::new(&source, 72, 73).unwrap() } )), @@ -1044,6 +1157,7 @@ mod tests { pragma: None, structs: vec![], functions: vec![Function { + generics: vec![], id: IdentifierExpression { value: String::from("main"), span: Span::new(&source, 33, 37).unwrap() @@ -1054,21 +1168,30 @@ mod tests { }))], statements: vec![Statement::Return(ReturnStatement { expressions: vec![Expression::if_else( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 62, 63).unwrap() + }, span: Span::new(&source, 62, 63).unwrap() } )), - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("2"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 69, 70).unwrap() + }, span: Span::new(&source, 69, 70).unwrap() } )), - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("3"), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 76, 77).unwrap() + }, span: Span::new(&source, 76, 77).unwrap() } )), @@ -1102,6 +1225,7 @@ mod tests { pragma: None, structs: vec![], functions: vec![Function { + generics: vec![], id: IdentifierExpression { value: String::from("main"), span: Span::new(&source, 4, 8).unwrap() @@ -1111,9 +1235,12 @@ mod tests { span: Span::new(&source, 15, 20).unwrap() }))], statements: vec![Statement::Return(ReturnStatement { - expressions: vec![Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), + expressions: vec![Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 31, 32).unwrap() + }, span: Span::new(&source, 31, 32).unwrap() } ))], @@ -1138,6 +1265,7 @@ mod tests { pragma: None, structs: vec![], functions: vec![Function { + generics: vec![], id: IdentifierExpression { value: String::from("main"), span: Span::new(&source, 4, 8).unwrap() @@ -1181,29 +1309,42 @@ mod tests { span: Span::new(&source, 36, 39).unwrap() }, accesses: vec![Access::Call(CallAccess { - expressions: vec![ - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("1"), - span: Span::new(&source, 40, 41).unwrap() - } - )), - Expression::add( - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("2"), - span: Span::new(&source, 43, 44).unwrap() + explicit_generics: None, + arguments: Arguments { + expressions: vec![ + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 40, 41).unwrap() + }, + span: Span::new(&source, 40, 41).unwrap() } )), - Expression::Constant(ConstantExpression::DecimalNumber( - DecimalNumberExpression { - value: String::from("3"), - span: Span::new(&source, 47, 48).unwrap() - } - )), - Span::new(&source, 43, 48).unwrap() - ), - ], + Expression::add( + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 43, 44).unwrap() + }, + span: Span::new(&source, 43, 44).unwrap() + } + )), + Expression::Literal(LiteralExpression::DecimalLiteral( + DecimalLiteralExpression { + suffix: None, + value: DecimalNumber { + span: Span::new(&source, 47, 48).unwrap() + }, + span: Span::new(&source, 47, 48).unwrap() + } + )), + Span::new(&source, 43, 48).unwrap() + ), + ], + span: Span::new(&source, 40, 48).unwrap() + }, span: Span::new(&source, 39, 49).unwrap() })], span: Span::new(&source, 36, 49).unwrap(), @@ -1221,16 +1362,16 @@ mod tests { #[test] fn playground() { - let source = r#"import "heyman" as yo + let source = r#"import "foo" as bar struct Foo { field[2] foo Bar bar } - def main(private field[23] a) -> (bool[234 + 6]): + def main

(private field[Q] a) -> (bool[234 + 6]): field a = 1 - a[32 + x][55] = y + a[32 + x][55] = foo::(y) for field i in 0..3 do assert(a == 1 + 2 + 3+ 4+ 5+ 6+ 6+ 7+ 8 + 4+ 5+ 3+ 4+ 2+ 3) endfor diff --git a/zokrates_stdlib/build.rs b/zokrates_stdlib/build.rs index 5081027d..cb2dd02c 100644 --- a/zokrates_stdlib/build.rs +++ b/zokrates_stdlib/build.rs @@ -15,5 +15,5 @@ fn export_stdlib() { let out_dir = env::var("OUT_DIR").unwrap(); let mut options = CopyOptions::new(); options.overwrite = true; - copy_items(&vec!["stdlib"], out_dir, &options).unwrap(); + copy_items(&["stdlib"], out_dir, &options).unwrap(); } diff --git a/zokrates_stdlib/stdlib/ecc/babyjubjubParams.zok b/zokrates_stdlib/stdlib/ecc/babyjubjubParams.zok index fcc433b1..11fd783a 100644 --- a/zokrates_stdlib/stdlib/ecc/babyjubjubParams.zok +++ b/zokrates_stdlib/stdlib/ecc/babyjubjubParams.zok @@ -4,7 +4,6 @@ // Note: parameters will be updated soon to be more compatible with zCash's implementation struct BabyJubJubParams { - // field JUBJUBE field JUBJUBC field JUBJUBA field JUBJUBD @@ -17,8 +16,7 @@ struct BabyJubJubParams { def main() -> BabyJubJubParams: -// Order of the curve E - // field JUBJUBE = 21888242871839275222246405745257275088614511777268538073601725287587578984328 + // Order of the curve for reference: 21888242871839275222246405745257275088614511777268538073601725287587578984328 field JUBJUBC = 8 // Cofactor field JUBJUBA = 168700 // Coefficient A field JUBJUBD = 168696 // Coefficient D @@ -40,7 +38,6 @@ return BabyJubJubParams { INFINITY: INFINITY, Gu: Gu, Gv: Gv, - // JUBJUBE: JUBJUBE, JUBJUBC: JUBJUBC, MONTA: MONTA, MONTB: MONTB diff --git a/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zok b/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zok index dd7b4300..7e90625c 100644 --- a/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zok +++ b/zokrates_stdlib/stdlib/ecc/edwardsScalarMult.zok @@ -15,8 +15,8 @@ def main(bool[256] exponent, field[2] pt, BabyJubJubParams context) -> field[2]: field[2] doubledP = pt field[2] accumulatedP = infinity - for field i in 0..256 do - field j = 255 - i + for u32 i in 0..256 do + u32 j = 255 - i field[2] candidateP = add(accumulatedP, doubledP, context) accumulatedP = if exponent[j] then candidateP else accumulatedP fi doubledP = add(doubledP, doubledP, context) diff --git a/zokrates_stdlib/stdlib/hashes/blake2/blake2s.zok b/zokrates_stdlib/stdlib/hashes/blake2/blake2s.zok new file mode 100644 index 00000000..8bbf125f --- /dev/null +++ b/zokrates_stdlib/stdlib/hashes/blake2/blake2s.zok @@ -0,0 +1,4 @@ +import "hashes/blake2/blake2s_p" as blake2s_p + +def main(u32[K][16] input) -> (u32[8]): + return blake2s_p(input, [0; 2]) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok b/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok new file mode 100644 index 00000000..239f9fce --- /dev/null +++ b/zokrates_stdlib/stdlib/hashes/blake2/blake2s_p.zok @@ -0,0 +1,98 @@ +// https://tools.ietf.org/html/rfc7693 + +import "EMBED/u32_to_bits" as to_bits +import "EMBED/u32_from_bits" as from_bits + +def right_rotate(u32 e) -> u32: + bool[32] b = to_bits(e) + return from_bits([...b[32 - N..], ...b[..32 - N]]) + +def blake2s_iv() -> (u32[8]): + return [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, + 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19 + ] + +def blake2s_sigma() -> (u32[10][16]): + return [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0] + ] + +def mixing_g(u32[16] v, u32 a, u32 b, u32 c, u32 d, u32 x, u32 y) -> (u32[16]): + v[a] = (v[a] + v[b] + x) + v[d] = right_rotate::<16>(v[d] ^ v[a]) + v[c] = (v[c] + v[d]) + v[b] = right_rotate::<12>(v[b] ^ v[c]) + v[a] = (v[a] + v[b] + y) + v[d] = right_rotate::<8>(v[d] ^ v[a]) + v[c] = (v[c] + v[d]) + v[b] = right_rotate::<7>(v[b] ^ v[c]) + return v + +def blake2s_compression(u32[8] h, u32[16] m, u32[2] t, bool last) -> (u32[8]): + u32[16] v = [...h, ...blake2s_iv()] + + v[12] = v[12] ^ t[0] + v[13] = v[13] ^ t[1] + v[14] = if last then v[14] ^ 0xFFFFFFFF else v[14] fi + + u32[10][16] sigma = blake2s_sigma() + + for u32 i in 0..10 do + u32[16] s = sigma[i] + v = mixing_g(v, 0, 4, 8, 12, m[s[0]], m[s[1]]) + v = mixing_g(v, 1, 5, 9, 13, m[s[2]], m[s[3]]) + v = mixing_g(v, 2, 6, 10, 14, m[s[4]], m[s[5]]) + v = mixing_g(v, 3, 7, 11, 15, m[s[6]], m[s[7]]) + v = mixing_g(v, 0, 5, 10, 15, m[s[8]], m[s[9]]) + v = mixing_g(v, 1, 6, 11, 12, m[s[10]], m[s[11]]) + v = mixing_g(v, 2, 7, 8, 13, m[s[12]], m[s[13]]) + v = mixing_g(v, 3, 4, 9, 14, m[s[14]], m[s[15]]) + endfor + + for u32 i in 0..8 do + h[i] = h[i] ^ v[i] ^ v[i + 8] + endfor + + return h + +def blake2s_init(u32[2] p) -> (u32[8]): + u32[8] iv = blake2s_iv() + u32[8] h = [ + iv[0] ^ 0x01010000 ^ 0x00000020, + iv[1], + iv[2], + iv[3], + iv[4], + iv[5], + iv[6] ^ p[0], + iv[7] ^ p[1] + ] + return h + +def main(u32[K][16] input, u32[2] p) -> (u32[8]): + u32[8] h = blake2s_init(p) + + u32 t0 = 0 + u32 t1 = 0 + + for u32 i in 0..K-1 do + t0 = (i + 1) * 64 + t1 = if t0 == 0 then t1 + 1 else t1 fi + h = blake2s_compression(h, input[i], [t0, t1], false) + endfor + + t0 = t0 + 64 + t1 = if t0 == 0 then t1 + 1 else t1 fi + + h = blake2s_compression(h, input[K - 1], [t0, t1], true) + return h \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7.zok b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7.zok new file mode 100644 index 00000000..525d3147 --- /dev/null +++ b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7.zok @@ -0,0 +1,18 @@ +import "./constants" as constants + +def main(field x_in, field k) -> field: + field[91] c = constants() + field t = 0 + field[ROUNDS] t2 = [0; ROUNDS] + field[ROUNDS] t4 = [0; ROUNDS] + field[ROUNDS] t6 = [0; ROUNDS] + field[ROUNDS] t7 = [0; ROUNDS] // we define t7 length +1 to reference implementation as ZoKrates wont allow conditional branching. -> out of bounds array error + for u32 i in 0..ROUNDS do + u32 i2 = if i == 0 then 0 else i - 1 fi + t = if i == 0 then k+x_in else k + t7[i2] + c[i] fi + t2[i] = t*t + t4[i] = t2[i]*t2[i] + t6[i] = t4[i]*t2[i] + t7[i] = t6[i]*t + endfor + return t6[ROUNDS - 1]*t + k \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R10.zok b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R10.zok deleted file mode 100644 index d3a58e6b..00000000 --- a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R10.zok +++ /dev/null @@ -1,18 +0,0 @@ -import "./constants" as constants - -def main(field x_in, field k) -> field: - field[91] c = constants() - field t = 0 - field[10] t2 = [0; 10] - field[10] t4 = [0; 10] - field[10] t6 = [0; 10] - field[10] t7 = [0; 10] // we define t7 length +1 to reference implementation as ZoKrates wont allow conditional branching. -> out of bounds array error - for field i in 0..10 do - field i2 = if i == 0 then 0 else i - 1 fi - t = if i == 0 then k+x_in else k + t7[i2] + c[i] fi - t2[i] = t*t - t4[i] = t2[i]*t2[i] - t6[i] = t4[i]*t2[i] - t7[i] = t6[i]*t - endfor - return t6[9]*t + k \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok deleted file mode 100644 index f231662f..00000000 --- a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R20.zok +++ /dev/null @@ -1,18 +0,0 @@ -import "./constants" as constants - -def main(field x_in, field k) -> field: - field[91] c = constants() - field t = 0 - field[20] t2 = [0; 20] - field[20] t4 = [0; 20] - field[20] t6 = [0; 20] - field[20] t7 = [0; 20] // we define t7 length +1 to reference implementation as ZoKrates wont allow conditional branching. -> out of bounds array error - for field i in 0..20 do - field i2 = if i == 0 then 0 else i - 1 fi - t = if i == 0 then k+x_in else k + t7[i2] + c[i] fi - t2[i] = t*t - t4[i] = t2[i]*t2[i] - t6[i] = t4[i]*t2[i] - t7[i] = t6[i]*t - endfor - return t6[19]*t + k \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R50.zok b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R50.zok deleted file mode 100644 index c3dc3d1e..00000000 --- a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R50.zok +++ /dev/null @@ -1,18 +0,0 @@ -import "./constants" as constants - -def main(field x_in, field k) -> field: - field[91] c = constants() - field t = 0 - field[50] t2 = [0; 50] - field[50] t4 = [0; 50] - field[50] t6 = [0; 50] - field[50] t7 = [0; 50] // we define t7 length +1 to reference implementation as ZoKrates wont allow conditional branching. - for field i in 0..50 do - field i2 = if i == 0 then 0 else i - 1 fi - t = if i == 0 then k+x_in else k + t7[i2] + c[i] fi - t2[i] = t*t - t4[i] = t2[i]*t2[i] - t6[i] = t4[i]*t2[i] - t7[i] = t6[i]*t - endfor - return t6[49]*t + k \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R90.zok b/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R90.zok deleted file mode 100644 index 83698aae..00000000 --- a/zokrates_stdlib/stdlib/hashes/mimc7/mimc7R90.zok +++ /dev/null @@ -1,18 +0,0 @@ -import "./constants" as constants - -def main(field x_in, field k) -> field: - field[91] c = constants() - field t = 0 - field[90] t2 = [0; 90] - field[90] t4 = [0; 90] - field[90] t6 = [0; 90] - field[90] t7 = [0; 90] // we define t7 length +1 to reference implementation as ZoKrates wont allow conditional branching. - for field i in 0..90 do - field i2 = if i == 0 then 0 else i - 1 fi - t = if i == 0 then k+x_in else k + t7[i2] + c[i] fi - t2[i] = t*t - t4[i] = t2[i]*t2[i] - t6[i] = t4[i]*t2[i] - t7[i] = t6[i]*t - endfor - return t6[89]*t + k \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcFeistel.zok b/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcFeistel.zok index 7257c39a..6e23e44f 100644 --- a/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcFeistel.zok +++ b/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcFeistel.zok @@ -5,15 +5,15 @@ def main(field xL_in, field xR_in, field k) -> field[2]: field[220] IV = IVconstants() field t = 0 - field nRounds = 220 + u32 nRounds = 220 field[220] t2 = [0; 220] //length: nRounds field[220] t4 = [0; 220] //... field[220] xL = [0; 220] //... field[220] xR = [0; 220] //... field c = 0 - for field i in 0..nRounds do - field idx = if i == 0 then 0 else i - 1 fi + for u32 i in 0..nRounds do + u32 idx = if i == 0 then 0 else i - 1 fi c = IV[i] t = if i == 0 then k + xL_in else k + xL[idx] + c fi diff --git a/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcSponge.zok b/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcSponge.zok index 033617ad..e58c1e37 100644 --- a/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcSponge.zok +++ b/zokrates_stdlib/stdlib/hashes/mimcSponge/mimcSponge.zok @@ -2,18 +2,18 @@ import "./mimcFeistel" as MiMCFeistel def main(field[2] ins, field k) -> field[3]: //nInputs = 2, nOutputs = 3, - field nInputs = 2 - field nOutputs = 3 + u32 nInputs = 2 + u32 nOutputs = 3 field[4][2] S = [[0; 2]; 4] // Dim: (nInputs + nOutputs - 1, 2) field[3] outs = [0; 3] - for field i in 0..nInputs do - field idx = if i == 0 then 0 else i - 1 fi + for u32 i in 0..nInputs do + u32 idx = if i == 0 then 0 else i - 1 fi S[i] = if i == 0 then MiMCFeistel(ins[0], 0, k) else MiMCFeistel(S[idx][0] + ins[i], S[idx][1], k) fi endfor outs[0] = S[nInputs - 1][0] - for field i in 0..(nOutputs - 1) do + for u32 i in 0..(nOutputs - 1) do field[2] feistelRes = MiMCFeistel(S[nInputs + i - 1][0], S[nInputs + i - 1][1], k) S[nInputs + i] = feistelRes outs[i + 1] = S[nInputs + i][0] diff --git a/zokrates_stdlib/stdlib/hashes/sha256/1024bit.zok b/zokrates_stdlib/stdlib/hashes/sha256/1024bit.zok index 6f52a95a..ff5b298b 100644 --- a/zokrates_stdlib/stdlib/hashes/sha256/1024bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha256/1024bit.zok @@ -1,14 +1,9 @@ -import "./IVconstants" as IVconstants -import "./shaRound" as sha256 +import "./sha256" as sha256 // A function that takes 4 u32[8] arrays as inputs, concatenates them, // and returns their sha256 compression as a u32[8]. // Note: no padding is applied def main(u32[8] a, u32[8] b, u32[8] c, u32[8] d) -> u32[8]: - - u32[8] IV = IVconstants() - u32[8] digest1 = sha256([...a, ...b], IV) - u32[8] digest2 = sha256([...c, ...d], digest1) - - return digest2 \ No newline at end of file + u32[8] res = sha256([[...a, ...b], [...c, ...d]]) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/sha256/1536bit.zok b/zokrates_stdlib/stdlib/hashes/sha256/1536bit.zok index f1b8860c..12895aad 100644 --- a/zokrates_stdlib/stdlib/hashes/sha256/1536bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha256/1536bit.zok @@ -1,15 +1,9 @@ -import "./IVconstants" as IVconstants -import "./shaRound" as sha256 +import "./sha256" as sha256 // A function that takes 6 u32[8] arrays as inputs, concatenates them, // and returns their sha256 compression as a u32[8]. // Note: no padding is applied def main(u32[8] a, u32[8] b, u32[8] c, u32[8] d, u32[8] e, u32[8] f) -> u32[8]: - - u32[8] IV = IVconstants() - u32[8] digest1 = sha256([...a, ...b], IV) - u32[8] digest2 = sha256([...c, ...d], digest1) - u32[8] digest3 = sha256([...e, ...f], digest2) - - return digest3 \ No newline at end of file + u32[8] res = sha256([[...a, ...b], [...c, ...d], [...e, ...f]]) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/hashes/sha256/512bit.zok b/zokrates_stdlib/stdlib/hashes/sha256/512bit.zok index e7772a13..4d1c1b60 100644 --- a/zokrates_stdlib/stdlib/hashes/sha256/512bit.zok +++ b/zokrates_stdlib/stdlib/hashes/sha256/512bit.zok @@ -1,5 +1,4 @@ -import "./IVconstants" as IVconstants -import "./shaRound" as sha256 +import "./sha256" as sha256 // A function that takes 2 u32[8] arrays as inputs, concatenates them, // and returns their sha256 compression as a u32[8]. @@ -7,4 +6,4 @@ import "./shaRound" as sha256 def main(u32[8] a, u32[8] b) -> u32[8]: - return sha256([...a, ...b], IVconstants()) + return sha256([[...a, ...b]]) diff --git a/zokrates_stdlib/stdlib/hashes/sha256/sha256.zok b/zokrates_stdlib/stdlib/hashes/sha256/sha256.zok new file mode 100644 index 00000000..cf067e48 --- /dev/null +++ b/zokrates_stdlib/stdlib/hashes/sha256/sha256.zok @@ -0,0 +1,15 @@ +import "./IVconstants" as IVconstants +import "./shaRound" as shaRound + +// A function that takes K u32[8] arrays as inputs, concatenates them, +// and returns their sha256 compression as a u32[8]. +// Note: no padding is applied + +def main(u32[K][16] a) -> u32[8]: + u32[8] current = IVconstants() + + for u32 i in 0..K do + current = shaRound(a[i], current) + endfor + + return current diff --git a/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok b/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok index 0b332eb3..b1248f6d 100644 --- a/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok +++ b/zokrates_stdlib/stdlib/hashes/sha256/shaRound.zok @@ -42,9 +42,9 @@ def right_rotate_25(u32 e) -> u32: bool[32] b = to_bits(e) return from_bits([...b[7..], ...b[..7]]) -def extend(u32[64] w, field i) -> u32: - u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 0x00000003) - u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 0x0000000a) +def extend(u32[64] w, u32 i) -> u32: + u32 s0 = right_rotate_7(w[i-15]) ^ right_rotate_18(w[i-15]) ^ (w[i-15] >> 3) + u32 s1 = right_rotate_17(w[i-2]) ^ right_rotate_19(w[i-2]) ^ (w[i-2] >> 10) return w[i-16] + s0 + w[i-7] + s1 def temp1(u32 e, u32 f, u32 g, u32 h, u32 k, u32 w) -> u32: @@ -84,7 +84,7 @@ def main(u32[16] input, u32[8] current) -> u32[8]: u32[64] w = [...input, ...[0x00000000; 48]] - for field i in 16..64 do + for u32 i in 16..64 do w[i] = extend(w, i) endfor @@ -97,7 +97,7 @@ def main(u32[16] input, u32[8] current) -> u32[8]: u32 g = h6 u32 h = h7 - for field i in 0..64 do + for u32 i in 0..64 do u32 t1 = temp1(e, f, g, h, k[i], w[i]) diff --git a/zokrates_stdlib/stdlib/utils/casts/bool_128_to_u32_4.zok b/zokrates_stdlib/stdlib/utils/casts/bool_128_to_u32_4.zok index 222835bf..75481860 100644 --- a/zokrates_stdlib/stdlib/utils/casts/bool_128_to_u32_4.zok +++ b/zokrates_stdlib/stdlib/utils/casts/bool_128_to_u32_4.zok @@ -1,4 +1,5 @@ -import "EMBED/u32_from_bits" as from_bits +import "./bool_array_to_u32_array" as bool_to_u32 def main(bool[128] bits) -> u32[4]: - return [from_bits(bits[0..32]), from_bits(bits[32..64]), from_bits(bits[64..96]), from_bits(bits[96..128])] + u32[4] res = bool_to_u32(bits) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/bool_256_to_u32_8.zok b/zokrates_stdlib/stdlib/utils/casts/bool_256_to_u32_8.zok index 538cde8b..6bf0cef9 100644 --- a/zokrates_stdlib/stdlib/utils/casts/bool_256_to_u32_8.zok +++ b/zokrates_stdlib/stdlib/utils/casts/bool_256_to_u32_8.zok @@ -1,4 +1,5 @@ -import "EMBED/u32_from_bits" as from_bits +import "./bool_array_to_u32_array" as bool_to_u32 def main(bool[256] bits) -> u32[8]: - return [from_bits(bits[0..32]), from_bits(bits[32..64]), from_bits(bits[64..96]), from_bits(bits[96..128]), from_bits(bits[128..160]), from_bits(bits[160..192]), from_bits(bits[192..224]), from_bits(bits[224..256])] + u32[8] res = bool_to_u32(bits) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok b/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok new file mode 100644 index 00000000..ba1f2693 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/casts/bool_array_to_u32_array.zok @@ -0,0 +1,15 @@ +import "EMBED/u32_from_bits" as from_bits + +// convert an array of bool to an array of u32 +// the sizes must match (one u32 for 32 bool) otherwise an error will happen +def main(bool[N] bits) -> u32[P]: + + assert(N == 32 * P) + + u32[P] res = [0; P] + + for u32 i in 0..P do + res[i] = from_bits(bits[32 * i..32 * (i + 1)]) + endfor + + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok index a94dd224..ec4faf53 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u16_to_field.zok @@ -3,8 +3,8 @@ import "EMBED/u16_to_bits" as to_bits def main(u16 i) -> field: bool[16] bits = to_bits(i) field res = 0 - for field j in 0..16 do - field exponent = 16 - j - 1 + for u32 j in 0..16 do + u32 exponent = 16 - j - 1 res = res + if bits[j] then 2 ** exponent else 0 fi endfor return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_4_to_bool_128.zok b/zokrates_stdlib/stdlib/utils/casts/u32_4_to_bool_128.zok index 35ad6188..f6bc7976 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_4_to_bool_128.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_4_to_bool_128.zok @@ -1,4 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits +import "./u32_array_to_bool_array" as to_bool_array def main(u32[4] input) -> bool[128]: - return [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3])] + bool[128] res = to_bool_array(input) + return res diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_8_to_bool_256.zok b/zokrates_stdlib/stdlib/utils/casts/u32_8_to_bool_256.zok index 84564fa7..6b08f2cf 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_8_to_bool_256.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_8_to_bool_256.zok @@ -1,4 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits +import "./u32_array_to_bool_array" as to_bool_array def main(u32[8] input) -> bool[256]: - return [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3]), ...to_bits(input[4]), ...to_bits(input[5]), ...to_bits(input[6]), ...to_bits(input[7])] + bool[256] res = to_bool_array(input) + return res diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok b/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok new file mode 100644 index 00000000..28c3d65d --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/casts/u32_array_to_bool_array.zok @@ -0,0 +1,15 @@ +import "EMBED/u32_to_bits" as to_bits + +def main(u32[N] input) -> bool[P]: + assert(P == 32 * N) + + bool[P] res = [false; P] + + for u32 i in 0..N do + bool[32] bits = to_bits(input[i]) + for u32 j in 0..32 do + res[i * 32 + j] = bits[j] + endfor + endfor + + return res diff --git a/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok index 6475a85c..a442a3bc 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u32_to_field.zok @@ -3,8 +3,8 @@ import "EMBED/u32_to_bits" as to_bits def main(u32 i) -> field: bool[32] bits = to_bits(i) field res = 0 - for field j in 0..32 do - field exponent = 32 - j - 1 + for u32 j in 0..32 do + u32 exponent = 32 - j - 1 res = res + if bits[j] then 2 ** exponent else 0 fi endfor return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok b/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok index 30a721c6..5de0047e 100644 --- a/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok +++ b/zokrates_stdlib/stdlib/utils/casts/u8_to_field.zok @@ -3,8 +3,8 @@ import "EMBED/u8_to_bits" as to_bits def main(u8 i) -> field: bool[8] bits = to_bits(i) field res = 0 - for field j in 0..8 do - field exponent = 8 - j - 1 + for u32 j in 0..8 do + u32 exponent = 8 - j - 1 res = res + if bits[j] then 2 ** exponent else 0 fi endfor return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/pack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/pack.zok new file mode 100644 index 00000000..cbf853b9 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/pack.zok @@ -0,0 +1,10 @@ +def main(bool[N] bits) -> field: + + field out = 0 + + for u32 j in 0..N do + u32 i = N - (j + 1) + out = out + if bits[i] then (2 ** j) else 0 fi + endfor + + return out diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/pack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/pack128.zok index e69ac943..63962151 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/pack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/pack128.zok @@ -1,15 +1,7 @@ #pragma curve bn128 +import "./pack" as pack + // pack 128 big-endian bits into one field element def main(bool[128] bits) -> field: - - field out = 0 - - field len = 128 - - for field j in 0..len do - field i = len - (j + 1) - out = out + if bits[i] then (2 ** j) else 0 fi - endfor - - return out \ No newline at end of file + return pack(bits) \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/pack256.zok b/zokrates_stdlib/stdlib/utils/pack/bool/pack256.zok index d92ac2a3..11f3e9b3 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/pack256.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/pack256.zok @@ -1,17 +1,9 @@ #pragma curve bn128 +import "./pack" as pack + // pack 256 big-endian bits into one field element // Note: This is not a injective operation as `p` is smaller than `2**256 - 1` for bn128 // For example, `[0, 0,..., 0]` and `bits(p)` both point to `0` def main(bool[256] bits) -> field: - - field out = 0 - - field len = 256 - - for field j in 0..len do - field i = len - (j + 1) - out = out + if bits[i] then (2 ** j) else 0 fi - endfor - - return out + return pack(bits) diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok new file mode 100644 index 00000000..38bd04a7 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack.zok @@ -0,0 +1,12 @@ +#pragma curve bn128 + +import "EMBED/unpack" as unpack + +// Unpack a field element as N big endian bits +def main(field i) -> bool[N]: + + assert(N <= 254) + + bool[N] res = unpack(i) + + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok index 66d89557..a24a244b 100644 --- a/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/bool/unpack128.zok @@ -1,13 +1,9 @@ #pragma curve bn128 -import "EMBED/unpack" as unpack +import "./unpack" as unpack // Unpack a field element as 128 big-endian bits // Precondition: the input is smaller or equal to `2**128 - 1` def main(field i) -> bool[128]: - - bool[254] b = unpack(i) - - assert(b[0..126] == [false; 126]) - - return b[126..254] \ No newline at end of file + bool[128] res = unpack::<128>(i) + return res \ No newline at end of file diff --git a/zokrates_stdlib/stdlib/utils/pack/u32/pack.zok b/zokrates_stdlib/stdlib/utils/pack/u32/pack.zok new file mode 100644 index 00000000..da8d1c68 --- /dev/null +++ b/zokrates_stdlib/stdlib/utils/pack/u32/pack.zok @@ -0,0 +1,9 @@ +import "../../casts/u32_array_to_bool_array" as to_bits +import "../bool/pack" + +// pack N big-endian bits into one field element +def main(u32[N] input) -> field: + + bool[N * 32] bits = to_bits(input) + + return pack(bits) diff --git a/zokrates_stdlib/stdlib/utils/pack/u32/pack128.zok b/zokrates_stdlib/stdlib/utils/pack/u32/pack128.zok index 26e7c28e..d9ec24b1 100644 --- a/zokrates_stdlib/stdlib/utils/pack/u32/pack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/u32/pack128.zok @@ -1,9 +1,5 @@ -import "EMBED/u32_to_bits" as to_bits -import "../bool/pack128" +import "./pack" as pack // pack 128 big-endian bits into one field element def main(u32[4] input) -> field: - - bool[128] bits = [...to_bits(input[0]), ...to_bits(input[1]), ...to_bits(input[2]), ...to_bits(input[3])] - - return pack128(bits) + return pack(input) diff --git a/zokrates_stdlib/stdlib/utils/pack/u32/unpack128.zok b/zokrates_stdlib/stdlib/utils/pack/u32/unpack128.zok index ebdb7544..24eeb83a 100644 --- a/zokrates_stdlib/stdlib/utils/pack/u32/unpack128.zok +++ b/zokrates_stdlib/stdlib/utils/pack/u32/unpack128.zok @@ -6,5 +6,4 @@ import "../../casts/bool_128_to_u32_4" as from_bits // Unpack a field element as 128 big-endian bits // Precondition: the input is smaller or equal to `2**128 - 1` def main(field i) -> u32[4]: - return from_bits(unpack(i)) \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.json b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.json new file mode 100644 index 00000000..eb0fe602 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/hashes/blake2/blake2s_1024bit.zok", + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.zok new file mode 100644 index 00000000..899f120f --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1024bit.zok @@ -0,0 +1,6 @@ +import "hashes/blake2/blake2s" + +def main(): + u32[8] h = blake2s::<2>([[0; 16]; 2]) + assert(h == [0x2005424E, 0x7BCE81B9, 0x2CCEF4DB, 0x94DBBA4D, 0x7D9B0750, 0xB53797EB, 0xD3572923, 0xCB01F823]) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.json b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.json new file mode 100644 index 00000000..e637ac91 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/hashes/blake2/blake2s_1536bit.zok", + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok new file mode 100644 index 00000000..28ba1529 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_1536bit.zok @@ -0,0 +1,6 @@ +import "hashes/blake2/blake2s" + +def main(): + u32[8] h = blake2s::<3>([[0x42424242; 16]; 3]) + assert(h == [0x804BD0E6, 0x90AD426E, 0x6BCF0BAD, 0xCB2D22C1, 0xF717B3C3, 0x4D9CB47F, 0xEB541A97, 0x061D9ED0]) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.json b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.json new file mode 100644 index 00000000..756f2038 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/hashes/blake2/blake2s_512bit.zok", + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok new file mode 100644 index 00000000..28d5edca --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_512bit.zok @@ -0,0 +1,6 @@ +import "hashes/blake2/blake2s" + +def main(): + u32[8] h = blake2s::<1>([[0; 16]]) + assert(h == [0x7CDB09AE, 0xB4424FD5, 0xB609EF90, 0xF61A54BC, 0x9B95E488, 0x353FC5B8, 0xE3566F9A, 0xA354B48A]) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.json b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.json new file mode 100644 index 00000000..2b7ea18b --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/hashes/blake2/blake2s_8192bit.zok", + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.zok new file mode 100644 index 00000000..f7a93b80 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_8192bit.zok @@ -0,0 +1,6 @@ +import "hashes/blake2/blake2s" + +def main(): + u32[8] h = blake2s::<16>([[0; 16]; 16]) + assert(h == [0x63665303, 0x046C502A, 0xC8514A5D, 0x67B7E833, 0xA9DAD591, 0xB421A8BC, 0x662A73A2, 0x2DA25AFB]) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.json b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.json new file mode 100644 index 00000000..2412b8f3 --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.json @@ -0,0 +1,15 @@ +{ + "entry_point": "./tests/tests/hashes/blake2/blake2s_p.zok", + "tests": [ + { + "input": { + "values": [] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok new file mode 100644 index 00000000..7a861e6c --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/blake2/blake2s_p.zok @@ -0,0 +1,6 @@ +import "hashes/blake2/blake2s_p" as blake2s + +def main(): + u32[8] h = blake2s::<1>([[0; 16]], [0x12345678, 0]) + assert(h == [0xC63C8C31, 0x5FCA3E69, 0x13850D46, 0x1DE48657, 0x208D2534, 0x9AA6E0EF, 0xAFEE7610, 0xFBDFAC13]) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.json b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.json similarity index 78% rename from zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.json rename to zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.json index b37df2d6..fe7581b2 100644 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.json +++ b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.json @@ -1,5 +1,5 @@ { - "entry_point": "./tests/tests/hashes/mimc7/mimc7R10.zok", + "entry_point": "./tests/tests/hashes/mimc7/mimc7.zok", "tests": [ { "input": { diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.zok b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.zok new file mode 100644 index 00000000..5303eb7c --- /dev/null +++ b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7.zok @@ -0,0 +1,15 @@ +import "hashes/mimc7/mimc7" + +def main(): + assert(mimc7::<10>(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289) + assert(mimc7::<10>(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145) + + assert(mimc7::<20>(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829) + assert(mimc7::<20>(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758) + + assert(mimc7::<50>(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173) + assert(mimc7::<50>(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789) + + assert(mimc7::<90>(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514) + assert(mimc7::<90>(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636) + return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.zok b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.zok deleted file mode 100644 index d41fc8c0..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R10.zok +++ /dev/null @@ -1,6 +0,0 @@ -import "hashes/mimc7/mimc7R10" - -def main(): - assert(mimc7R10(0, 0) == 6004544488495356385698286530147974336054653445122716140990101827963729149289) - assert(mimc7R10(100, 0) == 2977550761518141183167168643824354554080911485709001361112529600968315693145) - return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.json b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.json deleted file mode 100644 index d5b121ba..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "entry_point": "./tests/tests/hashes/mimc7/mimc7R20.zok", - "tests": [ - { - "input": { - "values": [] - }, - "output": { - "Ok": { - "values": [] - } - } - } - ] -} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.zok b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.zok deleted file mode 100644 index 6ef79bbb..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R20.zok +++ /dev/null @@ -1,6 +0,0 @@ -import "hashes/mimc7/mimc7R20" - -def main(): - assert(mimc7R20(0, 0) == 19139739902058628561064841933381604453445216873412991992755775746150759284829) - assert(mimc7R20(100, 0) == 8623418512398828792274158979964869393034224267928014534933203776818702139758) - return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.json b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.json deleted file mode 100644 index 37933e80..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "entry_point": "./tests/tests/hashes/mimc7/mimc7R50.zok", - "tests": [ - { - "input": { - "values": [] - }, - "output": { - "Ok": { - "values": [] - } - } - } - ] -} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.zok b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.zok deleted file mode 100644 index 2f6e513f..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R50.zok +++ /dev/null @@ -1,6 +0,0 @@ -import "hashes/mimc7/mimc7R50" - -def main(): - assert(mimc7R50(0, 0) == 3049953358280347916081509186284461274525472221619157672645224540758481713173) - assert(mimc7R50(100, 0) == 18511388995652647480418174218630545482006454713617579894396683237092568946789) - return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.json b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.json deleted file mode 100644 index b5c304e0..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "entry_point": "./tests/tests/hashes/mimc7/mimc7R90.zok", - "tests": [ - { - "input": { - "values": [] - }, - "output": { - "Ok": { - "values": [] - } - } - } - ] -} \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.zok b/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.zok deleted file mode 100644 index 0f197d0c..00000000 --- a/zokrates_stdlib/tests/tests/hashes/mimc7/mimc7R90.zok +++ /dev/null @@ -1,6 +0,0 @@ -import "hashes/mimc7/mimc7R90" - -def main(): - assert(mimc7R90(0, 0) == 20281265111705407344053532742843085357648991805359414661661476832595822221514) - assert(mimc7R90(100, 0) == 1010054095264022068840870550831559811104631937745987065544478027572003292636) - return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok b/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok index 4ce6771b..decdfeed 100644 --- a/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok +++ b/zokrates_stdlib/tests/tests/hashes/mimcSponge/mimcSponge.zok @@ -1,6 +1,6 @@ import "hashes/mimcSponge/mimcSponge" as mimcSponge def main(): - assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084, 13871743498877225461925335509899475799121918157213219438898506786048812913771, 21633608428713573518356618235457250173701815120501233429160399974209848779097]) - assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929, 6046202021237334713296073963481784771443313518730771623154467767602059802325, 16227963524034219233279650312501310147918176407385833422019760797222680144279]) + assert(mimcSponge([1,2], 3) == [20225509322021146255705869525264566735642015554514977326536820959638320229084,13871743498877225461925335509899475799121918157213219438898506786048812913771,21633608428713573518356618235457250173701815120501233429160399974209848779097f]) + assert(mimcSponge([0,0], 0) == [20636625426020718969131298365984859231982649550971729229988535915544421356929,6046202021237334713296073963481784771443313518730771623154467767602059802325,16227963524034219233279650312501310147918176407385833422019760797222680144279f]) return \ No newline at end of file diff --git a/zokrates_stdlib/tests/tests/utils/casts/to_bits.zok b/zokrates_stdlib/tests/tests/utils/casts/to_bits.zok index 6e445704..864083f2 100644 --- a/zokrates_stdlib/tests/tests/utils/casts/to_bits.zok +++ b/zokrates_stdlib/tests/tests/utils/casts/to_bits.zok @@ -7,7 +7,7 @@ def main(u32[4] a, u16[4] b, u8[4] c) -> (bool[4][32], bool[4][16], bool[4][8]): bool[4][16] e = [[false; 16]; 4] bool[4][8] f = [[false; 8]; 4] - for field i in 0..4 do + for u32 i in 0..4 do d[i] = u32_to_bits(a[i]) e[i] = u16_to_bits(b[i]) f[i] = u8_to_bits(c[i]) diff --git a/zokrates_stdlib/tests/tests/utils/casts/to_field.zok b/zokrates_stdlib/tests/tests/utils/casts/to_field.zok index 24851681..2fc5a9fe 100644 --- a/zokrates_stdlib/tests/tests/utils/casts/to_field.zok +++ b/zokrates_stdlib/tests/tests/utils/casts/to_field.zok @@ -7,13 +7,10 @@ def main(u32[4] a, u16[4] b, u8[4] c) -> (field[4], field[4], field[4]): field[4] e = [0; 4] field[4] f = [0; 4] - for field i in 0..4 do - field g = u32_to_field(a[i]) - d[i] = g - field h = u16_to_field(b[i]) - e[i] = h - field j = u8_to_field(c[i]) - f[i] = j + for u32 i in 0..4 do + d[i] = u32_to_field(a[i]) + e[i] = u16_to_field(b[i]) + f[i] = u8_to_field(c[i]) endfor return d, e, f \ No newline at end of file diff --git a/zokrates_test/src/lib.rs b/zokrates_test/src/lib.rs index 47ae1644..ebcd9556 100644 --- a/zokrates_test/src/lib.rs +++ b/zokrates_test/src/lib.rs @@ -16,7 +16,7 @@ enum Curve { #[derive(Serialize, Deserialize, Clone)] struct Tests { - pub entry_point: PathBuf, + pub entry_point: Option, pub curves: Option>, pub max_constraint_count: Option, pub tests: Vec, @@ -93,7 +93,15 @@ pub fn test_inner(test_path: &str) { let t: Tests = serde_json::from_reader(BufReader::new(File::open(Path::new(test_path)).unwrap())).unwrap(); - let curves = t.curves.clone().unwrap_or(vec![Curve::Bn128]); + let curves = t.curves.clone().unwrap_or_else(|| vec![Curve::Bn128]); + + let t = Tests { + entry_point: Some( + t.entry_point + .unwrap_or_else(|| PathBuf::from(String::from(test_path)).with_extension("zok")), + ), + ..t + }; // this function typically runs in a spawn thread whose stack size is small, leading to stack overflows // to avoid that, run the stack-heavy bit in a thread with a larger stack (8M) @@ -116,13 +124,16 @@ pub fn test_inner(test_path: &str) { } fn compile_and_run(t: Tests) { - let code = std::fs::read_to_string(&t.entry_point).unwrap(); + let entry_point = t.entry_point.unwrap(); + + let code = std::fs::read_to_string(&entry_point).unwrap(); let stdlib = std::fs::canonicalize("../zokrates_stdlib/stdlib").unwrap(); let resolver = FileSystemResolver::with_stdlib_root(stdlib.to_str().unwrap()); + let artifacts = compile::( code, - t.entry_point.clone(), + entry_point.clone(), Some(&resolver), &CompileConfig::default(), ) @@ -130,17 +141,14 @@ fn compile_and_run(t: Tests) { let bin = artifacts.prog(); - match t.max_constraint_count { - Some(target_count) => { - let count = bin.constraint_count(); + if let Some(target_count) = t.max_constraint_count { + let count = bin.constraint_count(); - println!( - "{} at {}% of max", - t.entry_point.display(), - (count as f32) / (target_count as f32) * 100_f32 - ); - } - _ => {} + println!( + "{} at {}% of max", + entry_point.display(), + (count as f32) / (target_count as f32) * 100_f32 + ); }; let interpreter = zokrates_core::ir::Interpreter::default(); @@ -148,25 +156,25 @@ fn compile_and_run(t: Tests) { for test in t.tests.into_iter() { let input = &test.input.values; - let output = interpreter.execute(bin, &(input.iter().cloned().map(parse_val).collect())); + let output = interpreter.execute( + bin, + &(input.iter().cloned().map(parse_val).collect::>()), + ); - match compare(output, test.output) { - Err(e) => { - let mut code = File::open(&t.entry_point).unwrap(); - let mut s = String::new(); - code.read_to_string(&mut s).unwrap(); - let context = format!( - "\n{}\nCalled with input ({})\n", - s, - input - .iter() - .map(|i| format!("{}", i)) - .collect::>() - .join(", ") - ); - panic!("{}{}", context, e) - } - Ok(..) => {} - }; + if let Err(e) = compare(output, test.output) { + let mut code = File::open(&entry_point).unwrap(); + let mut s = String::new(); + code.read_to_string(&mut s).unwrap(); + let context = format!( + "\n{}\nCalled with input ({})\n", + s, + input + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(", ") + ); + panic!("{}{}", context, e) + } } }