From a1a65378a7d1ebac88a5793a41efaca3c625e496 Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 16 May 2021 19:16:25 +0200 Subject: [PATCH] introduce block into ast and implement isolation on that --- test.zok | 32 ----- zokrates_cli/examples/for.zok | 4 +- zokrates_core/src/compile.rs | 4 - zokrates_core/src/flatten/mod.rs | 18 --- zokrates_core/src/semantics.rs | 2 +- .../src/static_analysis/bounds_checker.rs | 4 +- .../src/static_analysis/branch_isolator.rs | 102 ++++++++++++++++ .../static_analysis/flatten_complex_types.rs | 7 +- zokrates_core/src/static_analysis/mod.rs | 4 + .../src/static_analysis/propagation.rs | 14 ++- .../src/static_analysis/reducer/mod.rs | 112 +++++++++++------- zokrates_core/src/typed_absy/folder.rs | 48 +++++--- zokrates_core/src/typed_absy/integer.rs | 2 +- zokrates_core/src/typed_absy/mod.rs | 106 +++++++---------- zokrates_core/src/typed_absy/result_folder.rs | 56 ++++++--- .../tests/tests/panics/loop_bound.json | 33 ++++++ .../tests/tests/panics/loop_bound.zok | 9 ++ .../tests/tests/panics/panic_isolation.json | 64 ++++++++++ .../tests/tests/panics/panic_isolation.zok | 31 +++++ zokrates_test/src/lib.rs | 3 +- 20 files changed, 448 insertions(+), 207 deletions(-) delete mode 100644 test.zok create mode 100644 zokrates_core/src/static_analysis/branch_isolator.rs create mode 100644 zokrates_core_test/tests/tests/panics/loop_bound.json create mode 100644 zokrates_core_test/tests/tests/panics/loop_bound.zok create mode 100644 zokrates_core_test/tests/tests/panics/panic_isolation.json create mode 100644 zokrates_core_test/tests/tests/panics/panic_isolation.zok diff --git a/test.zok b/test.zok deleted file mode 100644 index 583be688..00000000 --- a/test.zok +++ /dev/null @@ -1,32 +0,0 @@ -def zero(field x) -> field: - assert(x == 0) - return 0 - -def inverse(field x) -> field: - assert(x != 0) - return 1/x - -def main(field x) -> field: - return if x == 0 then zero(x) else inverse(x) fi - -// def yes(bool x) -> bool: -// assert(x) -// return x - -// def no(bool x) -> bool: -// assert(!x) -// return !x - -// def main(bool x) -> bool: -// return if x then yes(x) else no(x) fi - -// def ones(field[2] a) -> field[2]: -// assert(a == [1, 1]) -// return a - -// def twos(field[2] a) -> field[2]: -// assert(a == [2, 2]) -// return a - -// def main(bool condition, field[2] a, field[2] b) -> field[2]: -// return if condition then ones(a) else twos(b) fi \ No newline at end of file diff --git a/zokrates_cli/examples/for.zok b/zokrates_cli/examples/for.zok index 7bdf2b8f..60187e4b 100644 --- a/zokrates_cli/examples/for.zok +++ b/zokrates_cli/examples/for.zok @@ -1,10 +1,10 @@ def bound(field x) -> u32: - return 41 + 1 + return 41 + x def main(field a) -> field: field x = 7 x = x + 1 - for u32 i in 0..bound(x) do + for u32 i in 0..bound(x) + bound(x + 1) do // x = x + a x = x + a endfor diff --git a/zokrates_core/src/compile.rs b/zokrates_core/src/compile.rs index bde13f90..87d27429 100644 --- a/zokrates_core/src/compile.rs +++ b/zokrates_core/src/compile.rs @@ -173,13 +173,9 @@ pub fn compile>( let (typed_ast, abi) = check_with_arena(source, location, resolver, &arena)?; - println!("{}", typed_ast); - // flatten input program let program_flattened = Flattener::flatten(typed_ast, config); - println!("{}", program_flattened); - // analyse (constant propagation after call resolution) let program_flattened = program_flattened.analyse(); diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 9c2802bf..6d06f526 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -450,15 +450,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { let condition_id = self.use_sym(); statements_flattened.push(FlatStatement::Definition(condition_id, condition)); - println!( - "BEFORE\n {}\n", - alternative_statements - .iter() - .map(|s| s.to_string()) - .collect::>() - .join("\n") - ); - let consequence_statements = self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( @@ -469,15 +460,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { ), ); - println!( - "AFTER\n {}\n", - alternative_statements - .iter() - .map(|s| s.to_string()) - .collect::>() - .join("\n") - ); - statements_flattened.extend(consequence_statements); statements_flattened.extend(alternative_statements); diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 1a6a36b8..8047c497 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1963,7 +1963,7 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Boolean(condition) => { match (consequence_checked, alternative_checked) { (TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => { - Ok(FieldElementExpression::IfElse(box condition, box FieldElementExpression::Block(vec![], box consequence), box FieldElementExpression::Block(vec![], box alternative)).into()) + Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into()) }, (TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => { Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into()) diff --git a/zokrates_core/src/static_analysis/bounds_checker.rs b/zokrates_core/src/static_analysis/bounds_checker.rs index de43d5f0..e6e6ee41 100644 --- a/zokrates_core/src/static_analysis/bounds_checker.rs +++ b/zokrates_core/src/static_analysis/bounds_checker.rs @@ -19,7 +19,7 @@ impl BoundsChecker { let array = self.fold_array_expression(array)?; let index = self.fold_uint_expression(index)?; - match (array.get_array_type().size.as_inner(), index.as_inner()) { + match (array.ty().size.as_inner(), index.as_inner()) { (UExpressionInner::Value(size), UExpressionInner::Value(index)) => { if index >= size { return Err(format!( @@ -53,7 +53,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for BoundsChecker { let to = self.fold_uint_expression(to)?; match ( - array.get_array_type().size.as_inner(), + array.ty().size.as_inner(), from.as_inner(), to.as_inner(), ) { diff --git a/zokrates_core/src/static_analysis/branch_isolator.rs b/zokrates_core/src/static_analysis/branch_isolator.rs new file mode 100644 index 00000000..4bc4aab5 --- /dev/null +++ b/zokrates_core/src/static_analysis/branch_isolator.rs @@ -0,0 +1,102 @@ +// Isolate branches means making sure that any branch is enclosed in a block. +// This is important, because we want any statement resulting from inlining any branch to be isolated from the coller, so that its panics can be conditional to the branch being logically run + +// `if c then a else b fi` becomes `if c then { a } else { b } fi`, and down the line any statements resulting from trating `a` and `b` can be safely kept inside the respective blocks. + +use crate::typed_absy::folder::*; +use crate::typed_absy::*; +use zokrates_field::Field; + +pub struct Isolator; + +impl Isolator { + pub fn isolate(p: TypedProgram) -> TypedProgram { + let mut isolator = Isolator; + isolator.fold_program(p) + } +} + +impl<'ast, T: Field> Folder<'ast, T> for Isolator { + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> FieldElementExpression<'ast, T> { + match e { + FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { + FieldElementExpression::IfElse( + box self.fold_boolean_expression(condition), + box FieldElementExpression::block(vec![],consequence), + box FieldElementExpression::block(vec![],alternative), + ) + } + e => fold_field_expression(self, e), + } + } + + fn fold_boolean_expression( + &mut self, + e: BooleanExpression<'ast, T>, + ) -> BooleanExpression<'ast, T> { + match e { + BooleanExpression::IfElse(box condition, box consequence, box alternative) => { + BooleanExpression::IfElse( + box self.fold_boolean_expression(condition), + box BooleanExpression::block(vec![],consequence), + box BooleanExpression::block(vec![],alternative), + ) + } + e => fold_boolean_expression(self, e), + } + } + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> UExpressionInner<'ast, T> { + match e { + UExpressionInner::IfElse(box condition, box consequence, box alternative) => { + UExpressionInner::IfElse( + box self.fold_boolean_expression(condition), + box UExpression::block(vec![],consequence), + box UExpression::block(vec![],alternative), + ) + } + e => fold_uint_expression_inner(self, bitwidth, e), + } + } + + fn fold_array_expression_inner( + &mut self, + array_ty: &ArrayType<'ast, T>, + e: ArrayExpressionInner<'ast, T>, + ) -> ArrayExpressionInner<'ast, T> { + match e { + ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { + ArrayExpressionInner::IfElse( + box self.fold_boolean_expression(condition), + box ArrayExpression::block(vec![],consequence), + box ArrayExpression::block(vec![],alternative), + ) + } + e => fold_array_expression_inner(self, array_ty, e), + } + } + + fn fold_struct_expression_inner( + &mut self, + struct_ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> StructExpressionInner<'ast, T> { + match e { + StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { + StructExpressionInner::IfElse( + box self.fold_boolean_expression(condition), + box StructExpression::block(vec![],consequence), + box StructExpression::block(vec![],alternative), + ) + } + e => fold_struct_expression_inner(self, struct_ty, e), + } + } +} diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index b89cc635..2c7eea72 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -683,11 +683,12 @@ pub fn fold_field_expression<'ast, T: Field>( _ => unreachable!(""), } } - typed_absy::FieldElementExpression::Block(statements, box value) => { - statements + typed_absy::FieldElementExpression::Block(block) => { + block + .statements .into_iter() .for_each(|s| f.fold_statement(statements_buffer, s)); - f.fold_field_expression(statements_buffer, value) + f.fold_field_expression(statements_buffer, *block.value) } } } diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index aaecb49c..7ee46057 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -5,6 +5,7 @@ //! @date 2018 mod bounds_checker; +mod branch_isolator; mod constant_inliner; mod flat_propagation; mod flatten_complex_types; @@ -17,6 +18,7 @@ mod variable_read_remover; mod variable_write_remover; use self::bounds_checker::BoundsChecker; +use self::branch_isolator::Isolator; use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::reducer::reduce_program; @@ -75,6 +77,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { pub fn analyse(self) -> Result<(ZirProgram<'ast, T>, Abi), Error> { // inline user-defined constants let r = ConstantInliner::inline(self); + // isolate branches + let r = Isolator::isolate(r); // reduce the program to a single function let r = reduce_program(r).map_err(Error::from)?; // generate abi diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 92ff8f79..8f5ef25d 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -147,7 +147,15 @@ fn is_constant(e: &TypedExpression) -> bool { StructExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)), _ => false, }, - TypedExpression::Uint(a) => matches!(a.as_inner(), UExpressionInner::Value(..)), + TypedExpression::Uint(a) => { + matches!(a.as_inner(), UExpressionInner::Value(..)) + || match a.as_inner() { + UExpressionInner::Block(_, e) => { + is_constant(&TypedExpression::from(*e.clone())) + } + _ => false, + } + } _ => false, } } @@ -167,7 +175,7 @@ fn remove_spreads(e: TypedExpression) -> TypedExpression { match e { TypedExpression::Array(a) => { - let array_ty = a.get_array_type(); + let array_ty = a.ty(); match a.into_inner() { ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( @@ -353,8 +361,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { let expression_list = self.fold_expression_list(expression_list)?; match expression_list { - l @ TypedExpressionList::Block(..) => fold_expression_list(self, l) - .map(|l| vec![TypedStatement::MultipleDefinition(assignees, l)]), TypedExpressionList::EmbedCall(embed, generics, arguments, types) => { let arguments: Vec<_> = arguments .into_iter() diff --git a/zokrates_core/src/static_analysis/reducer/mod.rs b/zokrates_core/src/static_analysis/reducer/mod.rs index ee5b14b9..ad42c9b8 100644 --- a/zokrates_core/src/static_analysis/reducer/mod.rs +++ b/zokrates_core/src/static_analysis/reducer/mod.rs @@ -22,11 +22,11 @@ use crate::typed_absy::Folder; use std::collections::HashMap; use crate::typed_absy::{ - ArrayExpression, ArrayExpressionInner, ArrayType, Block, BooleanExpression, CoreIdentifier, - DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, StructExpression, - StructExpressionInner, Type, Typed, TypedExpression, TypedExpressionList, TypedFunction, - TypedFunctionSymbol, TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, - Variable, + ArrayExpression, ArrayExpressionInner, ArrayType, BlockExpression, BooleanExpression, + CoreIdentifier, DeclarationFunctionKey, FieldElementExpression, FunctionCall, Identifier, + StructExpression, StructExpressionInner, StructType, Type, TypedExpression, + TypedExpressionList, TypedFunction, TypedFunctionSymbol, TypedModule, TypedProgram, + TypedStatement, UBitwidth, UExpression, UExpressionInner, Variable, }; use std::convert::{TryFrom, TryInto}; @@ -200,10 +200,7 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { output_type: Type<'ast, T>, ) -> Result where - E: Block<'ast, T> - + FunctionCall<'ast, T> - + TryFrom, Error = ()> - + std::fmt::Debug, + E: FunctionCall<'ast, T> + TryFrom, Error = ()> + std::fmt::Debug, { let generics = generics .into_iter() @@ -227,11 +224,8 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { match res { Ok(Output::Complete((statements, mut expressions))) => { self.complete &= true; - Ok(E::block( - statements, - expressions.pop().unwrap().try_into().unwrap(), - output_type, - )) + self.statement_buffer.extend(statements); + Ok(expressions.pop().unwrap().try_into().unwrap()) } Ok(Output::Incomplete((statements, expressions), delta_for_loop_versions)) => { self.complete = false; @@ -280,6 +274,29 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { type Error = Error; + fn fold_block_expression>( + &mut self, + b: BlockExpression<'ast, T, E>, + ) -> Result, Self::Error> { + // backup the statements and continue with a fresh state + let statement_buffer = std::mem::take(&mut self.statement_buffer); + + let block = fold_block_expression(self, b)?; + + // put the original statements back and extract the statements created by visiting the block + let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); + + // return the visited block, augmented with the statements created while visiting it + Ok(BlockExpression { + statements: block + .statements + .into_iter() + .chain(extra_statements) + .collect(), + ..block + }) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -433,7 +450,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { s => fold_statement(self, s), }; - res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + //res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + res } fn fold_boolean_expression( @@ -448,18 +466,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { } } - fn fold_uint_expression( + fn fold_uint_expression_inner( &mut self, - e: UExpression<'ast, T>, - ) -> Result, Self::Error> { - match e.as_inner() { - UExpressionInner::FunctionCall(key, generics, arguments) => self.fold_function_call( - key.clone(), - generics.clone(), - arguments.clone(), - e.get_type(), - ), - _ => fold_uint_expression(self, e), + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { + UExpressionInner::FunctionCall(key, generics, arguments) => self + .fold_function_call::>( + key, + generics, + arguments, + Type::Uint(bitwidth), + ) + .map(|e| e.into_inner()), + e => fold_uint_expression_inner(self, bitwidth, e), } } @@ -477,7 +498,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { fn fold_array_expression_inner( &mut self, - ty: &ArrayType<'ast, T>, + array_ty: &ArrayType<'ast, T>, e: ArrayExpressionInner<'ast, T>, ) -> Result, Self::Error> { match e { @@ -486,7 +507,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { key.clone(), generics, arguments.clone(), - Type::array(ty.clone()), + Type::array(array_ty.clone()), ) .map(|e| e.into_inner()), ArrayExpressionInner::Slice(box array, box from, box to) => { @@ -504,23 +525,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { } } } - _ => fold_array_expression_inner(self, &ty, e), + _ => fold_array_expression_inner(self, &array_ty, e), } } - fn fold_struct_expression( + fn fold_struct_expression_inner( &mut self, - e: StructExpression<'ast, T>, - ) -> Result, Self::Error> { - match e.as_inner() { + struct_ty: &StructType<'ast, T>, + e: StructExpressionInner<'ast, T>, + ) -> Result, Self::Error> { + match e { StructExpressionInner::FunctionCall(key, generics, arguments) => self - .fold_function_call( - key.clone(), - generics.clone(), - arguments.clone(), - e.get_type(), - ), - _ => fold_struct_expression(self, e), + .fold_function_call::>( + key, + generics, + arguments, + Type::Struct(struct_ty.clone()), + ) + .map(|e| e.into_inner()), + _ => fold_struct_expression_inner(self, struct_ty, e), } } } @@ -597,7 +620,14 @@ fn reduce_function<'ast, T: Field>( statements: f .statements .into_iter() - .map(|s| reducer.fold_statement(s)) + .map(|s| { + let res = reducer.fold_statement(s)?; + Ok(reducer + .statement_buffer + .drain(..) + .chain(res) + .collect::>()) + }) .collect::, _>>()? .into_iter() .flatten() diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index 35ff94e5..ff5175f6 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -4,6 +4,16 @@ use crate::typed_absy::types::{ArrayType, StructMember, StructType}; use crate::typed_absy::*; use zokrates_field::Field; +pub trait Fold<'ast, T: Field>: Sized { + fn fold>(self, f: &mut F) -> Self; +} + +impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Self { + f.fold_field_expression(self) + } +} + pub trait Folder<'ast, T: Field>: Sized { fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { fold_program(self, p) @@ -137,6 +147,13 @@ pub trait Folder<'ast, T: Field>: Sized { } } + fn fold_block_expression>( + &mut self, + block: BlockExpression<'ast, T, E>, + ) -> BlockExpression<'ast, T, E> { + fold_block_expression(self, block) + } + fn fold_array_expression(&mut self, e: ArrayExpression<'ast, T>) -> ArrayExpression<'ast, T> { fold_array_expression(self, e) } @@ -358,13 +375,9 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( e: FieldElementExpression<'ast, T>, ) -> FieldElementExpression<'ast, T> { match e { - FieldElementExpression::Block(statements, box value) => FieldElementExpression::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - box f.fold_field_expression(value), - ), + FieldElementExpression::Block(block) => { + FieldElementExpression::Block(f.fold_block_expression(block)) + } FieldElementExpression::Number(n) => FieldElementExpression::Number(n), FieldElementExpression::Identifier(id) => { FieldElementExpression::Identifier(f.fold_name(id)) @@ -688,6 +701,20 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_block_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>( + f: &mut F, + block: BlockExpression<'ast, T, E>, +) -> BlockExpression<'ast, T, E> { + BlockExpression { + statements: block + .statements + .into_iter() + .flat_map(|s| f.fold_statement(s)) + .collect(), + value: box block.value.fold(f), + } +} + pub fn fold_function<'ast, T: Field, F: Folder<'ast, T>>( f: &mut F, fun: TypedFunction<'ast, T>, @@ -749,13 +776,6 @@ pub fn fold_expression_list<'ast, T: Field, F: Folder<'ast, T>>( types.into_iter().map(|t| f.fold_type(t)).collect(), ) } - TypedExpressionList::Block(statements, values) => TypedExpressionList::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - values.into_iter().map(|v| f.fold_expression(v)).collect(), - ), } } diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 61a18748..486f6bb5 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -481,7 +481,7 @@ impl<'ast, T: Field> ArrayExpression<'ast, T> { array: Self, target_inner_ty: Type<'ast, T>, ) -> Result> { - let array_ty = array.get_array_type(); + let array_ty = array.ty(); // elements must fit in the target type match array.into_inner() { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 2d9b7464..d84d3083 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -489,13 +489,7 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { TypedStatement::Declaration(ref var) => write!(f, "{}", var), TypedStatement::Definition(ref lhs, ref rhs) => write!(f, "{} = {}", lhs, rhs), TypedStatement::Assertion(ref e) => write!(f, "assert({})", e), - TypedStatement::For(ref var, ref start, ref stop, ref list) => { - writeln!(f, "for {} in {}..{} do", var, start, stop)?; - for l in list { - writeln!(f, "\t\t{}", l)?; - } - write!(f, "\tendfor") - } + TypedStatement::For(..) => unreachable!("fmt_indented should be called instead"), TypedStatement::MultipleDefinition(ref ids, ref rhs) => { for (i, id) in ids.iter().enumerate() { write!(f, "{}", id)?; @@ -713,7 +707,6 @@ pub enum TypedExpressionList<'ast, T> { Vec>, Vec>, ), - Block(Vec>, Vec>), } impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> { @@ -721,9 +714,22 @@ impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> { match *self { TypedExpressionList::FunctionCall(_, _, _, ref types) => types.clone(), TypedExpressionList::EmbedCall(_, _, _, ref types) => types.clone(), - TypedExpressionList::Block(_, ref values) => { - values.iter().map(|v| v.get_type()).collect() - } + } + } +} + +#[derive(Clone, PartialEq, Debug, Hash, Eq)] +// a block expression which returns an `E` +pub struct BlockExpression<'ast, T, E> { + pub statements: Vec>, + pub value: Box, +} + +impl<'ast, T, E> BlockExpression<'ast, T, E> { + pub fn new(statements: Vec>, value: E) -> Self { + BlockExpression { + statements, + value: box value, } } } @@ -731,10 +737,7 @@ impl<'ast, T: Field> MultiTyped<'ast, T> for TypedExpressionList<'ast, T> { /// An expression of type `field` #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub enum FieldElementExpression<'ast, T> { - Block( - Vec>, - Box>, - ), + Block(BlockExpression<'ast, T, Self>), Number(T), Identifier(Identifier<'ast>), Add( @@ -925,7 +928,7 @@ impl<'ast, T: Clone> ArrayValue<'ast, T> { TypedExpressionOrSpread::Expression(e) => vec![Some(e.clone())], TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() { UExpressionInner::Value(size) => { - let array_ty = s.array.get_array_type().clone(); + let array_ty = s.array.ty().clone(); match s.array.into_inner() { ArrayExpressionInner::Value(v) => v @@ -1036,7 +1039,7 @@ impl<'ast, T: Clone> ArrayExpression<'ast, T> { self.inner } - pub fn get_array_type(&self) -> ArrayType<'ast, T> { + pub fn ty(&self) -> ArrayType<'ast, T> { ArrayType { size: self.size(), ty: box self.inner_type().clone(), @@ -1233,19 +1236,25 @@ impl<'ast, T> TryFrom> for IntExpression<'ast, T> { } } +impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for BlockExpression<'ast, T, E> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{{\n{}\n}}", + self.statements + .iter() + .map(|s| s.to_string()) + .chain(std::iter::once(self.value.to_string())) + .collect::>() + .join("\n") + ) + } +} + impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - FieldElementExpression::Block(ref statements, ref value) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once(value.to_string())) - .collect::>() - .join("\n") - ), + FieldElementExpression::Block(ref block) => write!(f, "{}", block), FieldElementExpression::Number(ref i) => write!(f, "{}f", i), FieldElementExpression::Identifier(ref var) => write!(f, "{}", var), FieldElementExpression::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), @@ -1542,22 +1551,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedExpressionList<'ast, T> { } write!(f, ")") } - TypedExpressionList::Block(ref statements, ref values) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once( - values - .iter() - .map(|v| v.to_string()) - .collect::>() - .join(", ") - )) - .collect::>() - .join("\n") - ), } } } @@ -1676,7 +1669,7 @@ impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> { impl<'ast, T: Clone> Select<'ast, T> for TypedExpression<'ast, T> { fn select>>(array: ArrayExpression<'ast, T>, index: I) -> Self { - match *array.get_array_type().ty { + match *array.ty().ty { Type::Array(..) => ArrayExpression::select(array, index).into(), Type::Struct(..) => StructExpression::select(array, index).into(), Type::FieldElement => FieldElementExpression::select(array, index).into(), @@ -1880,7 +1873,6 @@ pub trait Block<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self; } @@ -1888,10 +1880,8 @@ impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self { - assert_eq!(output_type, Type::FieldElement); - FieldElementExpression::Block(statements, box value) + FieldElementExpression::Block(BlockExpression::new(statements, value)) } } @@ -1899,9 +1889,7 @@ impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self { - assert_eq!(output_type, Type::Boolean); BooleanExpression::Block(statements, box value) } } @@ -1910,12 +1898,8 @@ impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self { - let bitwidth = match output_type { - Type::Uint(bitwidth) => bitwidth, - _ => unreachable!(), - }; + let bitwidth = value.bitwidth(); UExpressionInner::Block(statements, box value).annotate(bitwidth) } } @@ -1924,12 +1908,8 @@ impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self { - let array_ty = match output_type { - Type::Array(array_ty) => array_ty, - _ => unreachable!(), - }; + let array_ty = value.ty(); ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size) } } @@ -1938,12 +1918,8 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> { fn block( statements: Vec>, value: Self, - output_type: Type<'ast, T>, ) -> Self { - let struct_ty = match output_type { - Type::Struct(struct_ty) => struct_ty, - _ => unreachable!(), - }; + let struct_ty = value.ty().clone(); StructExpressionInner::Block(statements, box value).annotate(struct_ty) } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index f2aaaf61..25fff6a5 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -4,6 +4,16 @@ use crate::typed_absy::types::{ArrayType, StructMember, StructType}; use crate::typed_absy::*; use zokrates_field::Field; +pub trait ResultFold<'ast, T: Field>: Sized { + fn fold>(self, f: &mut F) -> Result; +} + +impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Result { + f.fold_field_expression(self) + } +} + pub trait ResultFolder<'ast, T: Field>: Sized { type Error; @@ -90,6 +100,13 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } } + fn fold_block_expression>( + &mut self, + block: BlockExpression<'ast, T, E>, + ) -> Result, Self::Error> { + fold_block_expression(self, block) + } + fn fold_array_type( &mut self, t: ArrayType<'ast, T>, @@ -418,14 +435,9 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( e: FieldElementExpression<'ast, T>, ) -> Result, F::Error> { let e = match e { - FieldElementExpression::Block(statements, box value) => FieldElementExpression::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|r| r.into_iter().flatten().collect())?, - box f.fold_field_expression(value)?, - ), + FieldElementExpression::Block(block) => { + FieldElementExpression::Block(f.fold_block_expression(block)?) + } FieldElementExpression::Number(n) => FieldElementExpression::Number(n), FieldElementExpression::Identifier(id) => { FieldElementExpression::Identifier(f.fold_name(id)?) @@ -502,6 +514,23 @@ pub fn fold_int_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( unreachable!() } +pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>( + f: &mut F, + block: BlockExpression<'ast, T, E>, +) -> Result, F::Error> { + Ok(BlockExpression { + statements: block + .statements + .into_iter() + .map(|s| f.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), + value: box block.value.fold(f)?, + }) +} + pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( f: &mut F, e: BooleanExpression<'ast, T>, @@ -834,17 +863,6 @@ pub fn fold_expression_list<'ast, T: Field, F: ResultFolder<'ast, T>>( .collect::>()?, )) } - TypedExpressionList::Block(statements, values) => Ok(TypedExpressionList::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|v| v.into_iter().flatten().collect())?, - values - .into_iter() - .map(|v| f.fold_expression(v)) - .collect::>()?, - )), } } diff --git a/zokrates_core_test/tests/tests/panics/loop_bound.json b/zokrates_core_test/tests/tests/panics/loop_bound.json new file mode 100644 index 00000000..77f44c82 --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/loop_bound.json @@ -0,0 +1,33 @@ +{ + "entry_point": "./tests/tests/panics/loop_bound.zok", + "curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"], + "tests": [ + { + "input": { + "values": [ + "0" + ] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "0", + "right": "1" + } + } + } + }, + { + "input": { + "values": [ + "1" + ] + }, + "output": { + "Ok": { + "values": [] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/panics/loop_bound.zok b/zokrates_core_test/tests/tests/panics/loop_bound.zok new file mode 100644 index 00000000..89f8e816 --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/loop_bound.zok @@ -0,0 +1,9 @@ +def throwing_bound(u32 x) -> u32: + assert(x == 1) + return 1 + +// Even if the bound is constant at compile time, it can throw at runtime +def main(u32 x): + for u32 i in 0..throwing_bound(x) do + endfor + return diff --git a/zokrates_core_test/tests/tests/panics/panic_isolation.json b/zokrates_core_test/tests/tests/panics/panic_isolation.json new file mode 100644 index 00000000..304266b6 --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/panic_isolation.json @@ -0,0 +1,64 @@ +{ + "entry_point": "./tests/tests/panics/panic_isolation.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [ + "1", + "42", + "42", + "0" + ] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "1", + "right": "21888242871839275222246405745257275088548364400416034343698204186575808495577" + } + } + } + }, + { + "input": { + "values": [ + "1", + "1", + "1", + "1" + ] + }, + "output": { + "Ok": { + "values": [ + "1", + "1", + "1", + "1" + ] + } + } + }, + { + "input": { + "values": [ + "0", + "2", + "2", + "0" + ] + }, + "output": { + "Ok": { + "values": [ + "0", + "2", + "2", + "0" + ] + } + } + } + ] +} diff --git a/zokrates_core_test/tests/tests/panics/panic_isolation.zok b/zokrates_core_test/tests/tests/panics/panic_isolation.zok new file mode 100644 index 00000000..c95ae723 --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/panic_isolation.zok @@ -0,0 +1,31 @@ +def zero(field x) -> field: + assert(x == 0) + return 0 + +def inverse(field x) -> field: + assert(x != 0) + return 1/x + +def yes(bool x) -> bool: + assert(x) + return x + +def no(bool x) -> bool: + assert(!x) + return x + +def ones(field[2] a) -> field[2]: + assert(a == [1, 1]) + return a + +def twos(field[2] a) -> field[2]: + assert(a == [2, 2]) + return a + +def main(bool condition, field[2] a, field x) -> (bool, field[2], field): + // first branch asserts that `condition` is true, second branch asserts that `condition` is false. This should never throw. + // first branch asserts that all elements in `a` are 1, 2 in the second branch. This should throw only if `a` is neither ones or zeroes + // first branch asserts that `x` is zero and returns it, second branch asserts that `x` isn't 0 and returns its inverse (which internally generates a failing assert if x is 0). This should never throw + return if condition then yes(condition) else no(condition) fi,\ + if condition then ones(a) else twos(a) fi,\ + if x == 0 then zero(x) else inverse(x) fi \ No newline at end of file diff --git a/zokrates_test/src/lib.rs b/zokrates_test/src/lib.rs index 04a57e9b..8730cb8b 100644 --- a/zokrates_test/src/lib.rs +++ b/zokrates_test/src/lib.rs @@ -163,8 +163,9 @@ fn compile_and_run(t: Tests) { let mut s = String::new(); code.read_to_string(&mut s).unwrap(); let context = format!( - "\n{}\nCalled with input ({})\n", + "\n{}\nCalled on curve {} with input ({})\n", s, + T::name(), input .iter() .map(|i| i.to_string())