diff --git a/zokrates_cli/examples/compile_errors/block_result_propagation.zok b/zokrates_cli/examples/compile_errors/block_result_propagation.zok new file mode 100644 index 00000000..82104dbc --- /dev/null +++ b/zokrates_cli/examples/compile_errors/block_result_propagation.zok @@ -0,0 +1,10 @@ +def throwing_bound(u32 x) -> u32: + assert(x == N) + return 1 + +// this should compile: the conditional, even though it can throw, has a constant compile-time value `1` +// the value of the blocks should be propagated out, so that `if x == 0 then 1 else 1 fi` can be determined to be `1` +def main(u32 x): + for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do + endfor + return diff --git a/zokrates_cli/examples/for.zok b/zokrates_cli/examples/for.zok index 60187e4b..4b60c68e 100644 --- a/zokrates_cli/examples/for.zok +++ b/zokrates_cli/examples/for.zok @@ -1,5 +1,5 @@ def bound(field x) -> u32: - return 41 + x + return 41 + 1 def main(field a) -> field: field x = 7 diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 2c7eea72..6ccd7600 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -206,7 +206,6 @@ impl<'ast, T: Field> Flattener { fn fold_expression_list( &mut self, statements_buffer: &mut Vec>, - es: typed_absy::TypedExpressionList<'ast, T>, ) -> zir::ZirExpressionList<'ast, T> { match es { @@ -227,7 +226,6 @@ impl<'ast, T: Field> Flattener { fn fold_field_expression( &mut self, statements_buffer: &mut Vec>, - e: typed_absy::FieldElementExpression<'ast, T>, ) -> zir::FieldElementExpression<'ast, T> { fold_field_expression(self, statements_buffer, e) @@ -328,11 +326,12 @@ pub fn fold_array_expression_inner<'ast, T: Field>( array: typed_absy::ArrayExpressionInner<'ast, T>, ) -> Vec> { match array { - typed_absy::ArrayExpressionInner::Block(statements, box value) => { - statements + typed_absy::ArrayExpressionInner::Block(block) => { + block + .statements .into_iter() .for_each(|s| f.fold_statement(statements_buffer, s)); - f.fold_array_expression(statements_buffer, value) + f.fold_array_expression(statements_buffer, *block.value) } typed_absy::ArrayExpressionInner::Identifier(id) => { let variables = flatten_identifier_rec( @@ -472,11 +471,12 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( struc: typed_absy::StructExpressionInner<'ast, T>, ) -> Vec> { match struc { - typed_absy::StructExpressionInner::Block(statements, box value) => { - statements + typed_absy::StructExpressionInner::Block(block) => { + block + .statements .into_iter() .for_each(|s| f.fold_statement(statements_buffer, s)); - f.fold_struct_expression(statements_buffer, value) + f.fold_struct_expression(statements_buffer, *block.value) } typed_absy::StructExpressionInner::Identifier(id) => { let variables = flatten_identifier_rec( @@ -699,11 +699,12 @@ pub fn fold_boolean_expression<'ast, T: Field>( e: typed_absy::BooleanExpression<'ast, T>, ) -> zir::BooleanExpression<'ast, T> { match e { - typed_absy::BooleanExpression::Block(statements, box value) => { - statements + typed_absy::BooleanExpression::Block(block) => { + block + .statements .into_iter() .for_each(|s| f.fold_statement(statements_buffer, s)); - f.fold_boolean_expression(statements_buffer, value) + f.fold_boolean_expression(statements_buffer, *block.value) } typed_absy::BooleanExpression::Value(v) => zir::BooleanExpression::Value(v), typed_absy::BooleanExpression::Identifier(id) => zir::BooleanExpression::Identifier( @@ -899,11 +900,12 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( e: typed_absy::UExpressionInner<'ast, T>, ) -> zir::UExpressionInner<'ast, T> { match e { - typed_absy::UExpressionInner::Block(statements, box value) => { - statements + typed_absy::UExpressionInner::Block(block) => { + block + .statements .into_iter() .for_each(|s| f.fold_statement(statements_buffer, s)); - f.fold_uint_expression(statements_buffer, value) + f.fold_uint_expression(statements_buffer, *block.value) .into_inner() } typed_absy::UExpressionInner::Value(v) => zir::UExpressionInner::Value(v), diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 8f5ef25d..0705d538 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -150,8 +150,8 @@ fn is_constant(e: &TypedExpression) -> bool { TypedExpression::Uint(a) => { matches!(a.as_inner(), UExpressionInner::Value(..)) || match a.as_inner() { - UExpressionInner::Block(_, e) => { - is_constant(&TypedExpression::from(*e.clone())) + UExpressionInner::Block(block) => { + is_constant(&TypedExpression::from(*block.value.clone())) } _ => false, } diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index ff5175f6..84342ed5 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -14,6 +14,30 @@ impl<'ast, T: Field> Fold<'ast, T> for FieldElementExpression<'ast, T> { } } +impl<'ast, T: Field> Fold<'ast, T> for BooleanExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Self { + f.fold_boolean_expression(self) + } +} + +impl<'ast, T: Field> Fold<'ast, T> for UExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Self { + f.fold_uint_expression(self) + } +} + +impl<'ast, T: Field> Fold<'ast, T> for StructExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Self { + f.fold_struct_expression(self) + } +} + +impl<'ast, T: Field> Fold<'ast, T> for ArrayExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Self { + f.fold_array_expression(self) + } +} + pub trait Folder<'ast, T: Field>: Sized { fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> { fold_program(self, p) @@ -274,13 +298,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( e: ArrayExpressionInner<'ast, T>, ) -> ArrayExpressionInner<'ast, T> { match e { - ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - box f.fold_array_expression(value), - ), + ArrayExpressionInner::Block(block) => { + ArrayExpressionInner::Block(f.fold_block_expression(block)) + } ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)), ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value( exprs @@ -332,13 +352,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( e: StructExpressionInner<'ast, T>, ) -> StructExpressionInner<'ast, T> { match e { - StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - box f.fold_struct_expression(value), - ), + StructExpressionInner::Block(block) => { + StructExpressionInner::Block(f.fold_block_expression(block)) + } StructExpressionInner::Identifier(id) => StructExpressionInner::Identifier(f.fold_name(id)), StructExpressionInner::Value(exprs) => { StructExpressionInner::Value(exprs.into_iter().map(|e| f.fold_expression(e)).collect()) @@ -455,13 +471,7 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( e: BooleanExpression<'ast, T>, ) -> BooleanExpression<'ast, T> { match e { - BooleanExpression::Block(statements, box value) => BooleanExpression::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - box f.fold_boolean_expression(value), - ), + BooleanExpression::Block(block) => BooleanExpression::Block(f.fold_block_expression(block)), BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)), BooleanExpression::FieldEq(box e1, box e2) => { @@ -585,13 +595,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( e: UExpressionInner<'ast, T>, ) -> UExpressionInner<'ast, T> { match e { - UExpressionInner::Block(statements, box value) => UExpressionInner::Block( - statements - .into_iter() - .flat_map(|s| f.fold_statement(s)) - .collect(), - box f.fold_uint_expression(value), - ), + UExpressionInner::Block(block) => UExpressionInner::Block(f.fold_block_expression(block)), UExpressionInner::Value(v) => UExpressionInner::Value(v), UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)), UExpressionInner::Add(box left, box right) => { diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 518a7590..a0cbcab9 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -585,16 +585,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.inner { - StructExpressionInner::Block(ref statements, ref value) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once(value.to_string())) - .collect::>() - .join("\n") - ), + StructExpressionInner::Block(ref block) => write!(f, "{}", block), StructExpressionInner::Identifier(ref var) => write!(f, "{}", var), StructExpressionInner::Value(ref values) => write!( f, @@ -822,10 +813,7 @@ impl<'ast, T> From for FieldElementExpression<'ast, T> { /// An expression of type `bool` #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub enum BooleanExpression<'ast, T> { - Block( - Vec>, - Box>, - ), + Block(BlockExpression<'ast, T, Self>), Identifier(Identifier<'ast>), Value(bool), FieldLt( @@ -982,7 +970,7 @@ impl<'ast, T> std::iter::FromIterator> for Arra #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub enum ArrayExpressionInner<'ast, T> { - Block(Vec>, Box>), + Block(BlockExpression<'ast, T, ArrayExpression<'ast, T>>), Identifier(Identifier<'ast>), Value(ArrayValue<'ast, T>), FunctionCall( @@ -1091,7 +1079,7 @@ impl<'ast, T> StructExpression<'ast, T> { #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub enum StructExpressionInner<'ast, T> { - Block(Vec>, Box>), + Block(BlockExpression<'ast, T, StructExpression<'ast, T>>), Identifier(Identifier<'ast>), Value(Vec>), FunctionCall( @@ -1305,16 +1293,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.inner { - UExpressionInner::Block(ref statements, ref value) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once(value.to_string())) - .collect::>() - .join("\n") - ), + UExpressionInner::Block(ref block) => write!(f, "{}", block,), UExpressionInner::Value(ref v) => write!(f, "{}", v), UExpressionInner::Identifier(ref var) => write!(f, "{}", var), UExpressionInner::Add(ref lhs, ref rhs) => write!(f, "({} + {})", lhs, rhs), @@ -1372,16 +1351,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - BooleanExpression::Block(ref statements, ref value) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once(value.to_string())) - .collect::>() - .join("\n") - ), + BooleanExpression::Block(ref block) => write!(f, "{}", block,), BooleanExpression::Identifier(ref var) => write!(f, "{}", var), BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), @@ -1439,16 +1409,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ArrayExpressionInner::Block(ref statements, ref value) => write!( - f, - "{{{}}}", - statements - .iter() - .map(|s| s.to_string()) - .chain(std::iter::once(value.to_string())) - .collect::>() - .join("\n") - ), + ArrayExpressionInner::Block(ref block) => write!(f, "{}", block,), ArrayExpressionInner::Identifier(ref var) => write!(f, "{}", var), ArrayExpressionInner::Value(ref values) => write!( f, @@ -1881,21 +1842,22 @@ impl<'ast, T: Field> Block<'ast, T> for FieldElementExpression<'ast, T> { impl<'ast, T: Field> Block<'ast, T> for BooleanExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { - BooleanExpression::Block(statements, box value) + BooleanExpression::Block(BlockExpression::new(statements, value)) } } impl<'ast, T: Field> Block<'ast, T> for UExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { let bitwidth = value.bitwidth(); - UExpressionInner::Block(statements, box value).annotate(bitwidth) + UExpressionInner::Block(BlockExpression::new(statements, value)).annotate(bitwidth) } } impl<'ast, T: Field> Block<'ast, T> for ArrayExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { let array_ty = value.ty(); - ArrayExpressionInner::Block(statements, box value).annotate(*array_ty.ty, array_ty.size) + ArrayExpressionInner::Block(BlockExpression::new(statements, value)) + .annotate(*array_ty.ty, array_ty.size) } } @@ -1903,6 +1865,6 @@ impl<'ast, T: Field> Block<'ast, T> for StructExpression<'ast, T> { fn block(statements: Vec>, value: Self) -> Self { let struct_ty = value.ty().clone(); - StructExpressionInner::Block(statements, box value).annotate(struct_ty) + StructExpressionInner::Block(BlockExpression::new(statements, value)).annotate(struct_ty) } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index 25fff6a5..5b2becc8 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -14,6 +14,30 @@ impl<'ast, T: Field> ResultFold<'ast, T> for FieldElementExpression<'ast, T> { } } +impl<'ast, T: Field> ResultFold<'ast, T> for BooleanExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Result { + f.fold_boolean_expression(self) + } +} + +impl<'ast, T: Field> ResultFold<'ast, T> for UExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Result { + f.fold_uint_expression(self) + } +} + +impl<'ast, T: Field> ResultFold<'ast, T> for ArrayExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Result { + f.fold_array_expression(self) + } +} + +impl<'ast, T: Field> ResultFold<'ast, T> for StructExpression<'ast, T> { + fn fold>(self, f: &mut F) -> Result { + f.fold_struct_expression(self) + } +} + pub trait ResultFolder<'ast, T: Field>: Sized { type Error; @@ -319,14 +343,9 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( e: ArrayExpressionInner<'ast, T>, ) -> Result, F::Error> { let e = match e { - ArrayExpressionInner::Block(statements, box value) => ArrayExpressionInner::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|r| r.into_iter().flatten().collect())?, - box f.fold_array_expression(value)?, - ), + ArrayExpressionInner::Block(block) => { + ArrayExpressionInner::Block(f.fold_block_expression(block)?) + } ArrayExpressionInner::Identifier(id) => ArrayExpressionInner::Identifier(f.fold_name(id)?), ArrayExpressionInner::Value(exprs) => ArrayExpressionInner::Value( exprs @@ -382,14 +401,9 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( e: StructExpressionInner<'ast, T>, ) -> Result, F::Error> { let e = match e { - StructExpressionInner::Block(statements, box value) => StructExpressionInner::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|r| r.into_iter().flatten().collect())?, - box f.fold_struct_expression(value)?, - ), + StructExpressionInner::Block(block) => { + StructExpressionInner::Block(f.fold_block_expression(block)?) + } StructExpressionInner::Identifier(id) => { StructExpressionInner::Identifier(f.fold_name(id)?) } @@ -536,14 +550,9 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( e: BooleanExpression<'ast, T>, ) -> Result, F::Error> { let e = match e { - BooleanExpression::Block(statements, box value) => BooleanExpression::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|r| r.into_iter().flatten().collect())?, - box f.fold_boolean_expression(value)?, - ), + BooleanExpression::Block(block) => { + BooleanExpression::Block(f.fold_block_expression(block)?) + } BooleanExpression::Value(v) => BooleanExpression::Value(v), BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?), BooleanExpression::FieldEq(box e1, box e2) => { @@ -671,14 +680,7 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( e: UExpressionInner<'ast, T>, ) -> Result, F::Error> { let e = match e { - UExpressionInner::Block(statements, box value) => UExpressionInner::Block( - statements - .into_iter() - .map(|s| f.fold_statement(s)) - .collect::, _>>() - .map(|r| r.into_iter().flatten().collect())?, - box f.fold_uint_expression(value)?, - ), + UExpressionInner::Block(block) => UExpressionInner::Block(f.fold_block_expression(block)?), UExpressionInner::Value(v) => UExpressionInner::Value(v), UExpressionInner::Identifier(id) => UExpressionInner::Identifier(f.fold_name(id)?), UExpressionInner::Add(box left, box right) => { diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index 3799a524..21a45b0c 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -175,7 +175,7 @@ impl<'ast, T> PartialEq for UExpression<'ast, T> { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum UExpressionInner<'ast, T> { - Block(Vec>, Box>), + Block(BlockExpression<'ast, T, UExpression<'ast, T>>), Identifier(Identifier<'ast>), Value(u128), Add(Box>, Box>), diff --git a/zokrates_core_test/tests/tests/panics/loop_bound.zok b/zokrates_core_test/tests/tests/panics/loop_bound.zok index 89f8e816..a373f9b3 100644 --- a/zokrates_core_test/tests/tests/panics/loop_bound.zok +++ b/zokrates_core_test/tests/tests/panics/loop_bound.zok @@ -1,9 +1,9 @@ -def throwing_bound(u32 x) -> u32: - assert(x == 1) +def throwing_bound(u32 x) -> u32: + assert(x == N) return 1 // Even if the bound is constant at compile time, it can throw at runtime def main(u32 x): - for u32 i in 0..throwing_bound(x) do + for u32 i in 0..throwing_bound::<1>(x) do endfor return