From ef4484f7efca8764af9ecc8c2617ef1fe1214646 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 31 May 2021 19:55:23 +0200 Subject: [PATCH 1/2] introduce if-else expression, implement case in propagation when consequence and alternative are equal, adjust tests --- .../tests/panics/conditional_bound_throw.zok | 10 ++++ .../conditional_bound_throw_no_isolation.json | 51 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 zokrates_core_test/tests/tests/panics/conditional_bound_throw.zok create mode 100644 zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json diff --git a/zokrates_core_test/tests/tests/panics/conditional_bound_throw.zok b/zokrates_core_test/tests/tests/panics/conditional_bound_throw.zok new file mode 100644 index 00000000..2258ee9c --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/conditional_bound_throw.zok @@ -0,0 +1,10 @@ +def throwing_bound(u32 x) -> u32: + assert(x == N) + return 1 + +// this compiles: the conditional, even though it can throw, has a constant compile-time value of `1` +// However, the assertions are still checked at runtime, which leads to panics without branch isolation. +def main(u32 x): + for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do + endfor + return \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json b/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json new file mode 100644 index 00000000..9d8548fa --- /dev/null +++ b/zokrates_core_test/tests/tests/panics/conditional_bound_throw_no_isolation.json @@ -0,0 +1,51 @@ +{ + "entry_point": "./tests/tests/panics/conditional_bound_throw.zok", + "curves": ["Bn128"], + "tests": [ + { + "input": { + "values": [ + "0" + ] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "0", + "right": "1" + } + } + } + }, + { + "input": { + "values": [ + "1" + ] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "1", + "right": "0" + } + } + } + }, + { + "input": { + "values": [ + "2" + ] + }, + "output": { + "Err": { + "UnsatisfiedConstraint": { + "left": "2", + "right": "0" + } + } + } + } + ] +} From 2efb8820a523c3f1852a054b247326b7d3bf5fd1 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 31 May 2021 19:59:20 +0200 Subject: [PATCH 2/2] add changelog, fmt, remove example which now compiles --- changelogs/unreleased/905-schaeff | 1 + .../block_result_propagation.zok | 10 - zokrates_core/src/semantics.rs | 16 +- .../src/static_analysis/branch_isolator.rs | 91 +----- .../static_analysis/flatten_complex_types.rs | 287 +++++++++--------- .../src/static_analysis/propagation.rs | 110 +++---- zokrates_core/src/typed_absy/folder.rs | 83 +++-- zokrates_core/src/typed_absy/integer.rs | 36 +-- zokrates_core/src/typed_absy/mod.rs | 104 +++---- zokrates_core/src/typed_absy/result_folder.rs | 83 +++-- zokrates_core/src/typed_absy/uint.rs | 6 +- 11 files changed, 366 insertions(+), 461 deletions(-) create mode 100644 changelogs/unreleased/905-schaeff delete mode 100644 zokrates_cli/examples/compile_errors/block_result_propagation.zok diff --git a/changelogs/unreleased/905-schaeff b/changelogs/unreleased/905-schaeff new file mode 100644 index 00000000..4923a545 --- /dev/null +++ b/changelogs/unreleased/905-schaeff @@ -0,0 +1 @@ +Improve propagation on if-else expressions when consequence and alternative are equal \ No newline at end of file diff --git a/zokrates_cli/examples/compile_errors/block_result_propagation.zok b/zokrates_cli/examples/compile_errors/block_result_propagation.zok deleted file mode 100644 index 82104dbc..00000000 --- a/zokrates_cli/examples/compile_errors/block_result_propagation.zok +++ /dev/null @@ -1,10 +0,0 @@ -def throwing_bound(u32 x) -> u32: - assert(x == N) - return 1 - -// this should compile: the conditional, even though it can throw, has a constant compile-time value `1` -// the value of the blocks should be propagated out, so that `if x == 0 then 1 else 1 fi` can be determined to be `1` -def main(u32 x): - for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do - endfor - return diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index 8fc13972..05428046 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1988,26 +1988,22 @@ impl<'ast, T: Field> Checker<'ast, T> { TypedExpression::Boolean(condition) => { match (consequence_checked, alternative_checked) { (TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => { - Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into()) + Ok(FieldElementExpression::if_else(condition, consequence, alternative).into()) }, (TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => { - Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into()) + Ok(BooleanExpression::if_else(condition, consequence, alternative).into()) }, (TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => { - let inner_type = consequence.inner_type().clone(); - let size = consequence.size(); - Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into()) + Ok(ArrayExpression::if_else(condition, consequence, alternative).into()) }, (TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => { - let ty = consequence.ty().clone(); - Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into()) + Ok(StructExpression::if_else(condition, consequence, alternative).into()) }, (TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => { - let bitwidth = consequence.bitwidth(); - Ok(UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth).into()) + Ok(UExpression::if_else(condition, consequence, alternative).into()) }, (TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => { - Ok(IntExpression::IfElse(box condition, box consequence, box alternative).into()) + Ok(IntExpression::if_else(condition, consequence, alternative).into()) }, (c, a) => Err(ErrorInner { pos: Some(pos), diff --git a/zokrates_core/src/static_analysis/branch_isolator.rs b/zokrates_core/src/static_analysis/branch_isolator.rs index 17d4d84d..c9d848d8 100644 --- a/zokrates_core/src/static_analysis/branch_isolator.rs +++ b/zokrates_core/src/static_analysis/branch_isolator.rs @@ -17,86 +17,17 @@ impl Isolator { } impl<'ast, T: Field> Folder<'ast, T> for Isolator { - fn fold_field_expression( + fn fold_if_else_expression< + E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + IfElse<'ast, T>, + >( &mut self, - e: FieldElementExpression<'ast, T>, - ) -> FieldElementExpression<'ast, T> { - match e { - FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { - FieldElementExpression::IfElse( - box self.fold_boolean_expression(condition), - box FieldElementExpression::block(vec![], consequence.fold(self)), - box FieldElementExpression::block(vec![], alternative.fold(self)), - ) - } - e => fold_field_expression(self, e), - } - } - - fn fold_boolean_expression( - &mut self, - e: BooleanExpression<'ast, T>, - ) -> BooleanExpression<'ast, T> { - match e { - BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - BooleanExpression::IfElse( - box self.fold_boolean_expression(condition), - box BooleanExpression::block(vec![], consequence.fold(self)), - box BooleanExpression::block(vec![], alternative.fold(self)), - ) - } - e => fold_boolean_expression(self, e), - } - } - - fn fold_uint_expression_inner( - &mut self, - bitwidth: UBitwidth, - e: UExpressionInner<'ast, T>, - ) -> UExpressionInner<'ast, T> { - match e { - UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - UExpressionInner::IfElse( - box self.fold_boolean_expression(condition), - box UExpression::block(vec![], consequence.fold(self)), - box UExpression::block(vec![], alternative.fold(self)), - ) - } - e => fold_uint_expression_inner(self, bitwidth, e), - } - } - - fn fold_array_expression_inner( - &mut self, - array_ty: &ArrayType<'ast, T>, - e: ArrayExpressionInner<'ast, T>, - ) -> ArrayExpressionInner<'ast, T> { - match e { - ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { - ArrayExpressionInner::IfElse( - box self.fold_boolean_expression(condition), - box ArrayExpression::block(vec![], consequence.fold(self)), - box ArrayExpression::block(vec![], alternative.fold(self)), - ) - } - e => fold_array_expression_inner(self, array_ty, e), - } - } - - fn fold_struct_expression_inner( - &mut self, - struct_ty: &StructType<'ast, T>, - e: StructExpressionInner<'ast, T>, - ) -> StructExpressionInner<'ast, T> { - match e { - StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - StructExpressionInner::IfElse( - box self.fold_boolean_expression(condition), - box StructExpression::block(vec![], consequence.fold(self)), - box StructExpression::block(vec![], alternative.fold(self)), - ) - } - e => fold_struct_expression_inner(self, struct_ty, e), - } + _: &E::Ty, + e: IfElseExpression<'ast, T, E>, + ) -> IfElseOrExpression<'ast, T, E> { + IfElseOrExpression::IfElse(IfElseExpression::new( + self.fold_boolean_expression(*e.condition), + E::block(vec![], e.consequence.fold(self)), + E::block(vec![], e.alternative.fold(self)), + )) } } diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index 736cf7de..bd04ba9e 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -49,6 +49,64 @@ fn flatten_identifier_rec<'ast>( } } +trait Flatten<'ast, T: Field> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec>; +} + +impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::FieldElementExpression<'ast, T> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec> { + vec![f.fold_field_expression(statements_buffer, self).into()] + } +} + +impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::BooleanExpression<'ast, T> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec> { + vec![f.fold_boolean_expression(statements_buffer, self).into()] + } +} + +impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::UExpression<'ast, T> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec> { + vec![f.fold_uint_expression(statements_buffer, self).into()] + } +} + +impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::ArrayExpression<'ast, T> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec> { + f.fold_array_expression(statements_buffer, self) + } +} + +impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::StructExpression<'ast, T> { + fn flatten( + self, + f: &mut Flattener, + statements_buffer: &mut Vec>, + ) -> Vec> { + f.fold_struct_expression(statements_buffer, self) + } +} + impl<'ast, T: Field> Flattener { pub fn flatten(p: typed_absy::TypedProgram) -> zir::ZirProgram { let mut f = Flattener::default(); @@ -223,7 +281,15 @@ impl<'ast, T: Field> Flattener { } } - pub fn fold_member_expression( + fn fold_if_else_expression>( + &mut self, + statements_buffer: &mut Vec>, + c: typed_absy::IfElseExpression<'ast, T, E>, + ) -> Vec> { + fold_if_else_expression(self, statements_buffer, c) + } + + fn fold_member_expression( &mut self, statements_buffer: &mut Vec>, m: typed_absy::MemberExpression<'ast, T, E>, @@ -231,7 +297,7 @@ impl<'ast, T: Field> Flattener { fold_member_expression(self, statements_buffer, m) } - pub fn fold_select_expression( + fn fold_select_expression( &mut self, statements_buffer: &mut Vec>, select: typed_absy::SelectExpression<'ast, T, E>, @@ -289,7 +355,7 @@ impl<'ast, T: Field> Flattener { } } -pub fn fold_statement<'ast, T: Field>( +fn fold_statement<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, s: typed_absy::TypedStatement<'ast, T>, @@ -334,7 +400,7 @@ pub fn fold_statement<'ast, T: Field>( statements_buffer.extend(res); } -pub fn fold_array_expression_inner<'ast, T: Field>( +fn fold_array_expression_inner<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, ty: &typed_absy::types::ConcreteType, @@ -376,44 +442,8 @@ pub fn fold_array_expression_inner<'ast, T: Field>( exprs } typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(), - typed_absy::ArrayExpressionInner::IfElse( - box condition, - box consequence, - box alternative, - ) => { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; - - let condition = f.fold_boolean_expression(statements_buffer, condition); - let consequence = f.fold_array_expression(&mut consequence_statements, consequence); - let alternative = f.fold_array_expression(&mut alternative_statements, alternative); - - assert_eq!(consequence.len(), alternative.len()); - - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); - - use crate::zir::IfElse; - - consequence - .into_iter() - .zip(alternative.into_iter()) - .map(|(c, a)| match (c, a) { - (zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => { - zir::FieldElementExpression::if_else(condition.clone(), c, a).into() - } - (zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => { - zir::BooleanExpression::if_else(condition.clone(), c, a).into() - } - (zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => { - zir::UExpression::if_else(condition.clone(), c, a).into() - } - _ => unreachable!(), - }) - .collect() + typed_absy::ArrayExpressionInner::IfElse(c) => { + f.fold_if_else_expression(statements_buffer, c) } typed_absy::ArrayExpressionInner::Member(m) => { f.fold_member_expression(statements_buffer, m) @@ -452,7 +482,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>( } } -pub fn fold_struct_expression_inner<'ast, T: Field>( +fn fold_struct_expression_inner<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, ty: &typed_absy::types::ConcreteStructType, @@ -487,44 +517,8 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( .flat_map(|e| f.fold_expression(statements_buffer, e)) .collect(), typed_absy::StructExpressionInner::FunctionCall(..) => unreachable!(), - typed_absy::StructExpressionInner::IfElse( - box condition, - box consequence, - box alternative, - ) => { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; - - let condition = f.fold_boolean_expression(statements_buffer, condition); - let consequence = f.fold_struct_expression(&mut consequence_statements, consequence); - let alternative = f.fold_struct_expression(&mut alternative_statements, alternative); - - assert_eq!(consequence.len(), alternative.len()); - - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); - - use zir::IfElse; - - consequence - .into_iter() - .zip(alternative.into_iter()) - .map(|(c, a)| match (c, a) { - (zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => { - zir::FieldElementExpression::if_else(condition.clone(), c, a).into() - } - (zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => { - zir::BooleanExpression::if_else(condition.clone(), c, a).into() - } - (zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => { - zir::UExpression::if_else(condition.clone(), c, a).into() - } - _ => unreachable!(), - }) - .collect() + typed_absy::StructExpressionInner::IfElse(c) => { + f.fold_if_else_expression(statements_buffer, c) } typed_absy::StructExpressionInner::Member(m) => { f.fold_member_expression(statements_buffer, m) @@ -535,7 +529,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field>( } } -pub fn fold_member_expression<'ast, T: Field, E>( +fn fold_member_expression<'ast, T: Field, E>( f: &mut Flattener, statements_buffer: &mut Vec>, m: typed_absy::MemberExpression<'ast, T, E>, @@ -571,7 +565,7 @@ pub fn fold_member_expression<'ast, T: Field, E>( s[offset..offset + size].to_vec() } -pub fn fold_select_expression<'ast, T: Field, E>( +fn fold_select_expression<'ast, T: Field, E>( f: &mut Flattener, statements_buffer: &mut Vec>, select: typed_absy::SelectExpression<'ast, T, E>, @@ -593,7 +587,47 @@ pub fn fold_select_expression<'ast, T: Field, E>( } } -pub fn fold_field_expression<'ast, T: Field>( +fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>( + f: &mut Flattener, + statements_buffer: &mut Vec>, + c: typed_absy::IfElseExpression<'ast, T, E>, +) -> Vec> { + let mut consequence_statements = vec![]; + let mut alternative_statements = vec![]; + + let condition = f.fold_boolean_expression(statements_buffer, *c.condition); + let consequence = c.consequence.flatten(f, &mut consequence_statements); + let alternative = c.alternative.flatten(f, &mut alternative_statements); + + assert_eq!(consequence.len(), alternative.len()); + + statements_buffer.push(zir::ZirStatement::IfElse( + condition.clone(), + consequence_statements, + alternative_statements, + )); + + use crate::zir::IfElse; + + consequence + .into_iter() + .zip(alternative.into_iter()) + .map(|(c, a)| match (c, a) { + (zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => { + zir::FieldElementExpression::if_else(condition.clone(), c, a).into() + } + (zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => { + zir::BooleanExpression::if_else(condition.clone(), c, a).into() + } + (zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => { + zir::UExpression::if_else(condition.clone(), c, a).into() + } + _ => unreachable!(), + }) + .collect() +} + +fn fold_field_expression<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, e: typed_absy::FieldElementExpression<'ast, T>, @@ -647,26 +681,12 @@ pub fn fold_field_expression<'ast, T: Field>( typed_absy::FieldElementExpression::Pos(box e) => { f.fold_field_expression(statements_buffer, e) } - typed_absy::FieldElementExpression::IfElse( - box condition, - box consequence, - box alternative, - ) => { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; - - let condition = f.fold_boolean_expression(statements_buffer, condition); - let consequence = f.fold_field_expression(&mut consequence_statements, consequence); - let alternative = f.fold_field_expression(&mut alternative_statements, alternative); - - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); - - zir::FieldElementExpression::IfElse(box condition, box consequence, box alternative) - } + typed_absy::FieldElementExpression::IfElse(c) => f + .fold_if_else_expression(statements_buffer, c) + .pop() + .unwrap() + .try_into() + .unwrap(), typed_absy::FieldElementExpression::FunctionCall(..) => unreachable!(""), typed_absy::FieldElementExpression::Select(select) => f .fold_select_expression(statements_buffer, select) @@ -690,7 +710,7 @@ pub fn fold_field_expression<'ast, T: Field>( } } -pub fn fold_boolean_expression<'ast, T: Field>( +fn fold_boolean_expression<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, e: typed_absy::BooleanExpression<'ast, T>, @@ -836,22 +856,12 @@ pub fn fold_boolean_expression<'ast, T: Field>( let e = f.fold_boolean_expression(statements_buffer, e); zir::BooleanExpression::Not(box e) } - typed_absy::BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; - - let condition = f.fold_boolean_expression(statements_buffer, condition); - let consequence = f.fold_boolean_expression(&mut consequence_statements, consequence); - let alternative = f.fold_boolean_expression(&mut alternative_statements, alternative); - - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); - - zir::BooleanExpression::IfElse(box condition, box consequence, box alternative) - } + typed_absy::BooleanExpression::IfElse(c) => f + .fold_if_else_expression(statements_buffer, c) + .pop() + .unwrap() + .try_into() + .unwrap(), typed_absy::BooleanExpression::FunctionCall(..) => unreachable!(), typed_absy::BooleanExpression::Select(select) => f .fold_select_expression(statements_buffer, select) @@ -868,7 +878,7 @@ pub fn fold_boolean_expression<'ast, T: Field>( } } -pub fn fold_uint_expression<'ast, T: Field>( +fn fold_uint_expression<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, e: typed_absy::UExpression<'ast, T>, @@ -877,7 +887,7 @@ pub fn fold_uint_expression<'ast, T: Field>( .annotate(e.bitwidth.to_usize()) } -pub fn fold_uint_expression_inner<'ast, T: Field>( +fn fold_uint_expression_inner<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, bitwidth: UBitwidth, @@ -1008,26 +1018,17 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( ) .unwrap() .into_inner(), - typed_absy::UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let mut consequence_statements = vec![]; - let mut alternative_statements = vec![]; - - let condition = f.fold_boolean_expression(statements_buffer, condition); - let consequence = f.fold_uint_expression(&mut consequence_statements, consequence); - let alternative = f.fold_uint_expression(&mut alternative_statements, alternative); - - statements_buffer.push(zir::ZirStatement::IfElse( - condition.clone(), - consequence_statements, - alternative_statements, - )); - - zir::UExpressionInner::IfElse(box condition, box consequence, box alternative) - } + typed_absy::UExpressionInner::IfElse(c) => zir::UExpression::try_from( + f.fold_if_else_expression(statements_buffer, c) + .pop() + .unwrap(), + ) + .unwrap() + .into_inner(), } } -pub fn fold_function<'ast, T: Field>( +fn fold_function<'ast, T: Field>( f: &mut Flattener, fun: typed_absy::TypedFunction<'ast, T>, ) -> zir::ZirFunction<'ast, T> { @@ -1052,7 +1053,7 @@ pub fn fold_function<'ast, T: Field>( } } -pub fn fold_array_expression<'ast, T: Field>( +fn fold_array_expression<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, e: typed_absy::ArrayExpression<'ast, T>, @@ -1069,7 +1070,7 @@ pub fn fold_array_expression<'ast, T: Field>( ) } -pub fn fold_struct_expression<'ast, T: Field>( +fn fold_struct_expression<'ast, T: Field>( f: &mut Flattener, statements_buffer: &mut Vec>, e: typed_absy::StructExpression<'ast, T>, @@ -1081,7 +1082,7 @@ pub fn fold_struct_expression<'ast, T: Field>( ) } -pub fn fold_program<'ast, T: Field>( +fn fold_program<'ast, T: Field>( f: &mut Flattener, mut p: typed_absy::TypedProgram<'ast, T>, ) -> zir::ZirProgram<'ast, T> { diff --git a/zokrates_core/src/static_analysis/propagation.rs b/zokrates_core/src/static_analysis/propagation.rs index 42d22851..2438971c 100644 --- a/zokrates_core/src/static_analysis/propagation.rs +++ b/zokrates_core/src/static_analysis/propagation.rs @@ -290,6 +290,35 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { fold_function(self, f) } + fn fold_if_else_expression< + E: Expr<'ast, T> + IfElse<'ast, T> + PartialEq + ResultFold<'ast, T>, + >( + &mut self, + _: &E::Ty, + e: IfElseExpression<'ast, T, E>, + ) -> Result, Self::Error> { + Ok( + match ( + self.fold_boolean_expression(*e.condition)?, + e.consequence.fold(self)?, + e.alternative.fold(self)?, + ) { + (BooleanExpression::Value(true), consequence, _) => { + IfElseOrExpression::Expression(consequence.into_inner()) + } + (BooleanExpression::Value(false), _, alternative) => { + IfElseOrExpression::Expression(alternative.into_inner()) + } + (_, consequence, alternative) if consequence == alternative => { + IfElseOrExpression::Expression(consequence.into_inner()) + } + (condition, consequence, alternative) => IfElseOrExpression::IfElse( + IfElseExpression::new(condition, consequence, alternative), + ), + }, + ) + } + fn fold_statement( &mut self, s: TypedStatement<'ast, T>, @@ -913,19 +942,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { box e2.annotate(bitwidth), )), }, - UExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_uint_expression(consequence)?; - let alternative = self.fold_uint_expression(alternative)?; - match self.fold_boolean_expression(condition)? { - BooleanExpression::Value(true) => Ok(consequence.into_inner()), - BooleanExpression::Value(false) => Ok(alternative.into_inner()), - c => Ok(UExpressionInner::IfElse( - box c, - box consequence, - box alternative, - )), - } - } UExpressionInner::Not(box e) => { let e = self.fold_uint_expression(e)?.into_inner(); match e { @@ -1035,19 +1051,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { )), } } - FieldElementExpression::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_field_expression(consequence)?; - let alternative = self.fold_field_expression(alternative)?; - match self.fold_boolean_expression(condition)? { - BooleanExpression::Value(true) => Ok(consequence), - BooleanExpression::Value(false) => Ok(alternative), - c => Ok(FieldElementExpression::IfElse( - box c, - box consequence, - box alternative, - )), - } - } e => fold_field_expression(self, e), } } @@ -1167,19 +1170,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }, None => Ok(ArrayExpressionInner::Identifier(id)), }, - ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_array_expression(consequence)?; - let alternative = self.fold_array_expression(alternative)?; - match self.fold_boolean_expression(condition)? { - BooleanExpression::Value(true) => Ok(consequence.into_inner()), - BooleanExpression::Value(false) => Ok(alternative.into_inner()), - c => Ok(ArrayExpressionInner::IfElse( - box c, - box consequence, - box alternative, - )), - } - } e => fold_array_expression_inner(self, ty, e), } } @@ -1197,19 +1187,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { }, None => Ok(StructExpressionInner::Identifier(id)), }, - StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_struct_expression(consequence)?; - let alternative = self.fold_struct_expression(alternative)?; - match self.fold_boolean_expression(condition)? { - BooleanExpression::Value(true) => Ok(consequence.into_inner()), - BooleanExpression::Value(false) => Ok(alternative.into_inner()), - c => Ok(StructExpressionInner::IfElse( - box c, - box consequence, - box alternative, - )), - } - } StructExpressionInner::Value(v) => { let v = v.into_iter().zip(ty.iter()).map(|(v, member)| match self.fold_expression(v) { @@ -1433,19 +1410,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { e => Ok(BooleanExpression::Not(box e)), } } - BooleanExpression::IfElse(box condition, box consequence, box alternative) => { - let consequence = self.fold_boolean_expression(consequence)?; - let alternative = self.fold_boolean_expression(alternative)?; - match self.fold_boolean_expression(condition)? { - BooleanExpression::Value(true) => Ok(consequence), - BooleanExpression::Value(false) => Ok(alternative), - c => Ok(BooleanExpression::IfElse( - box c, - box consequence, - box alternative, - )), - } - } e => fold_boolean_expression(self, e), } } @@ -1531,10 +1495,10 @@ mod tests { #[test] fn if_else_true() { - let e = FieldElementExpression::IfElse( - box BooleanExpression::Value(true), - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + let e = FieldElementExpression::if_else( + BooleanExpression::Value(true), + FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::Number(Bn128Field::from(3)), ); assert_eq!( @@ -1545,10 +1509,10 @@ mod tests { #[test] fn if_else_false() { - let e = FieldElementExpression::IfElse( - box BooleanExpression::Value(false), - box FieldElementExpression::Number(Bn128Field::from(2)), - box FieldElementExpression::Number(Bn128Field::from(3)), + let e = FieldElementExpression::if_else( + BooleanExpression::Value(false), + FieldElementExpression::Number(Bn128Field::from(2)), + FieldElementExpression::Number(Bn128Field::from(3)), ); assert_eq!( diff --git a/zokrates_core/src/typed_absy/folder.rs b/zokrates_core/src/typed_absy/folder.rs index c7ff3f95..9aa27358 100644 --- a/zokrates_core/src/typed_absy/folder.rs +++ b/zokrates_core/src/typed_absy/folder.rs @@ -193,6 +193,20 @@ pub trait Folder<'ast, T: Field>: Sized { fold_block_expression(self, block) } + fn fold_if_else_expression< + E: Expr<'ast, T> + + Fold<'ast, T> + + Block<'ast, T> + + IfElse<'ast, T> + + From>, + >( + &mut self, + ty: &E::Ty, + e: IfElseExpression<'ast, T, E>, + ) -> IfElseOrExpression<'ast, T, E> { + fold_if_else_expression(self, ty, e) + } + fn fold_member_expression< E: Expr<'ast, T> + Member<'ast, T> + From>, >( @@ -375,13 +389,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { - ArrayExpressionInner::IfElse( - box f.fold_boolean_expression(condition), - box f.fold_array_expression(consequence), - box f.fold_array_expression(alternative), - ) - } + ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) { + IfElseOrExpression::IfElse(s) => ArrayExpressionInner::IfElse(s), + IfElseOrExpression::Expression(u) => u, + }, ArrayExpressionInner::Select(select) => match f.fold_select_expression(ty, select) { SelectOrExpression::Select(s) => ArrayExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, @@ -425,13 +436,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - StructExpressionInner::IfElse( - box f.fold_boolean_expression(condition), - box f.fold_struct_expression(consequence), - box f.fold_struct_expression(alternative), - ) - } + StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) { + IfElseOrExpression::IfElse(s) => StructExpressionInner::IfElse(s), + IfElseOrExpression::Expression(u) => u, + }, StructExpressionInner::Select(select) => match f.fold_select_expression(ty, select) { SelectOrExpression::Select(s) => StructExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, @@ -490,11 +498,11 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( FieldElementExpression::Pos(box e) } - FieldElementExpression::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond); - let cons = f.fold_field_expression(cons); - let alt = f.fold_field_expression(alt); - FieldElementExpression::IfElse(box cond, box cons, box alt) + FieldElementExpression::IfElse(c) => { + match f.fold_if_else_expression(&Type::FieldElement, c) { + IfElseOrExpression::IfElse(s) => FieldElementExpression::IfElse(s), + IfElseOrExpression::Expression(u) => u, + } } FieldElementExpression::FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::FieldElement, function_call) { @@ -518,6 +526,23 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>( } } +pub fn fold_if_else_expression< + 'ast, + T: Field, + E: Expr<'ast, T> + Fold<'ast, T> + IfElse<'ast, T> + From>, + F: Folder<'ast, T>, +>( + f: &mut F, + _: &E::Ty, + e: IfElseExpression<'ast, T, E>, +) -> IfElseOrExpression<'ast, T, E> { + IfElseOrExpression::IfElse(IfElseExpression::new( + f.fold_boolean_expression(*e.condition), + e.consequence.fold(f), + e.alternative.fold(f), + )) +} + pub fn fold_member_expression< 'ast, T: Field, @@ -652,12 +677,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - BooleanExpression::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond); - let cons = f.fold_boolean_expression(cons); - let alt = f.fold_boolean_expression(alt); - BooleanExpression::IfElse(box cond, box cons, box alt) - } + BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c) { + IfElseOrExpression::IfElse(s) => BooleanExpression::IfElse(s), + IfElseOrExpression::Expression(u) => u, + }, BooleanExpression::Select(select) => match f.fold_select_expression(&Type::Boolean, select) { SelectOrExpression::Select(s) => BooleanExpression::Select(s), @@ -782,12 +805,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond); - let cons = f.fold_uint_expression(cons); - let alt = f.fold_uint_expression(alt); - UExpressionInner::IfElse(box cond, box cons, box alt) - } + UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c) { + IfElseOrExpression::IfElse(s) => UExpressionInner::IfElse(s), + IfElseOrExpression::Expression(u) => u, + }, UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m) { MemberOrExpression::Member(m) => UExpressionInner::Member(m), MemberOrExpression::Expression(u) => u, diff --git a/zokrates_core/src/typed_absy/integer.rs b/zokrates_core/src/typed_absy/integer.rs index 955ad45a..851acf7a 100644 --- a/zokrates_core/src/typed_absy/integer.rs +++ b/zokrates_core/src/typed_absy/integer.rs @@ -2,8 +2,8 @@ use crate::typed_absy::types::{ArrayType, Type}; use crate::typed_absy::UBitwidth; use crate::typed_absy::{ ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse, - Select, SelectExpression, StructExpression, Typed, TypedExpression, TypedExpressionOrSpread, - TypedSpread, UExpression, UExpressionInner, + IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression, + TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner, }; use num_bigint::BigUint; use std::convert::TryFrom; @@ -142,11 +142,7 @@ pub enum IntExpression<'ast, T> { Div(Box>, Box>), Rem(Box>, Box>), Pow(Box>, Box>), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, IntExpression<'ast, T>>), Select(SelectExpression<'ast, T, IntExpression<'ast, T>>), Xor(Box>, Box>), And(Box>, Box>), @@ -261,11 +257,7 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> { IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by), IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by), IntExpression::Not(ref e) => write!(f, "!{}", e), - IntExpression::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ), + IntExpression::IfElse(ref c) => write!(f, "{}", c), } } } @@ -315,13 +307,11 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> { )), IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)), IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)), - IntExpression::IfElse(box condition, box consequence, box alternative) => { - Ok(Self::IfElse( - box condition, - box Self::try_from_int(consequence)?, - box Self::try_from_int(alternative)?, - )) - } + IntExpression::IfElse(c) => Ok(Self::IfElse(IfElseExpression::new( + *c.condition, + Self::try_from_int(*c.consequence)?, + Self::try_from_int(*c.alternative)?, + ))), IntExpression::Select(select) => { let array = *select.array; let index = *select.index; @@ -430,10 +420,10 @@ impl<'ast, T: Field> UExpression<'ast, T> { Self::try_from_int(e1, bitwidth)?, e2, )), - IfElse(box condition, box consequence, box alternative) => Ok(UExpression::if_else( - condition, - Self::try_from_int(consequence, bitwidth)?, - Self::try_from_int(alternative, bitwidth)?, + IfElse(c) => Ok(UExpression::if_else( + *c.condition, + Self::try_from_int(*c.consequence, bitwidth)?, + Self::try_from_int(*c.alternative, bitwidth)?, )), Select(select) => { let array = *select.array; diff --git a/zokrates_core/src/typed_absy/mod.rs b/zokrates_core/src/typed_absy/mod.rs index 7e8fefb1..0b08413b 100644 --- a/zokrates_core/src/typed_absy/mod.rs +++ b/zokrates_core/src/typed_absy/mod.rs @@ -641,13 +641,7 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> { StructExpressionInner::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } - StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => { - write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ) - } + StructExpressionInner::IfElse(ref c) => write!(f, "{}", c), StructExpressionInner::Member(ref m) => write!(f, "{}", m), StructExpressionInner::Select(ref select) => write!(f, "{}", select), } @@ -792,6 +786,33 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> { } } +#[derive(Clone, PartialEq, Debug, Hash, Eq)] +pub struct IfElseExpression<'ast, T, E> { + pub condition: Box>, + pub consequence: Box, + pub alternative: Box, +} + +impl<'ast, T, E> IfElseExpression<'ast, T, E> { + pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self { + IfElseExpression { + condition: box condition, + consequence: box consequence, + alternative: box alternative, + } + } +} + +impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'ast, T, E> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "if {} then {} else {} fi", + self.condition, self.consequence, self.alternative + ) + } +} + #[derive(Clone, PartialEq, Debug, Hash, Eq)] pub struct FunctionCallExpression<'ast, T, E> { pub function_key: DeclarationFunctionKey<'ast>, @@ -870,11 +891,7 @@ pub enum FieldElementExpression<'ast, T> { Box>, Box>, ), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, Self>), Neg(Box>), Pos(Box>), FunctionCall(FunctionCallExpression<'ast, T, Self>), @@ -974,11 +991,7 @@ pub enum BooleanExpression<'ast, T> { Box>, ), Not(Box>), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, Self>), Member(MemberExpression<'ast, T, Self>), FunctionCall(FunctionCallExpression<'ast, T, Self>), Select(SelectExpression<'ast, T, Self>), @@ -1085,11 +1098,7 @@ pub enum ArrayExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(ArrayValue<'ast, T>), FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, ArrayExpression<'ast, T>>), Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>), Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>), Slice( @@ -1190,11 +1199,7 @@ pub enum StructExpressionInner<'ast, T> { Identifier(Identifier<'ast>), Value(Vec>), FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, StructExpression<'ast, T>>), Member(MemberExpression<'ast, T, StructExpression<'ast, T>>), Select(SelectExpression<'ast, T, StructExpression<'ast, T>>), } @@ -1333,13 +1338,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> { FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs), FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e), FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e), - FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => { - write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ) - } + FieldElementExpression::IfElse(ref c) => write!(f, "{}", c), FieldElementExpression::FunctionCall(ref function_call) => { write!(f, "{}", function_call) } @@ -1373,11 +1372,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> { UExpressionInner::Pos(ref e) => write!(f, "(+{})", e), UExpressionInner::Select(ref select) => write!(f, "{}", select), UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), - UExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ), + UExpressionInner::IfElse(ref c) => write!(f, "{}", c), UExpressionInner::Member(ref m) => write!(f, "{}", m), } } @@ -1406,11 +1401,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> { BooleanExpression::Not(ref exp) => write!(f, "!{}", exp), BooleanExpression::Value(b) => write!(f, "{}", b), BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call), - BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ), + BooleanExpression::IfElse(ref c) => write!(f, "{}", c), BooleanExpression::Member(ref m) => write!(f, "{}", m), BooleanExpression::Select(ref select) => write!(f, "{}", select), } @@ -1432,11 +1423,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> { .join(", ") ), ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call), - ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!( - f, - "if {} then {} else {} fi", - condition, consequent, alternative - ), + ArrayExpressionInner::IfElse(ref c) => write!(f, "{}", c), ArrayExpressionInner::Member(ref m) => write!(f, "{}", m), ArrayExpressionInner::Select(ref select) => write!(f, "{}", select), ArrayExpressionInner::Slice(ref a, ref from, ref to) => { @@ -1616,6 +1603,11 @@ pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> { Expression(E::Inner), } +pub enum IfElseOrExpression<'ast, T, E: Expr<'ast, T>> { + IfElse(IfElseExpression<'ast, T, E>), + Expression(E::Inner), +} + pub trait IfElse<'ast, T> { fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self) -> Self; @@ -1627,7 +1619,7 @@ impl<'ast, T> IfElse<'ast, T> for FieldElementExpression<'ast, T> { consequence: Self, alternative: Self, ) -> Self { - FieldElementExpression::IfElse(box condition, box consequence, box alternative) + FieldElementExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) } } @@ -1637,7 +1629,7 @@ impl<'ast, T> IfElse<'ast, T> for IntExpression<'ast, T> { consequence: Self, alternative: Self, ) -> Self { - IntExpression::IfElse(box condition, box consequence, box alternative) + IntExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) } } @@ -1647,7 +1639,7 @@ impl<'ast, T> IfElse<'ast, T> for BooleanExpression<'ast, T> { consequence: Self, alternative: Self, ) -> Self { - BooleanExpression::IfElse(box condition, box consequence, box alternative) + BooleanExpression::IfElse(IfElseExpression::new(condition, consequence, alternative)) } } @@ -1659,7 +1651,8 @@ impl<'ast, T> IfElse<'ast, T> for UExpression<'ast, T> { ) -> Self { let bitwidth = consequence.bitwidth; - UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth) + UExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) + .annotate(bitwidth) } } @@ -1671,7 +1664,7 @@ impl<'ast, T: Clone> IfElse<'ast, T> for ArrayExpression<'ast, T> { ) -> Self { let ty = consequence.inner_type().clone(); let size = consequence.size(); - ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) + ArrayExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) .annotate(ty, size) } } @@ -1683,7 +1676,8 @@ impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> { alternative: Self, ) -> Self { let ty = consequence.ty().clone(); - StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty) + StructExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative)) + .annotate(ty) } } diff --git a/zokrates_core/src/typed_absy/result_folder.rs b/zokrates_core/src/typed_absy/result_folder.rs index ad712694..bef3daaf 100644 --- a/zokrates_core/src/typed_absy/result_folder.rs +++ b/zokrates_core/src/typed_absy/result_folder.rs @@ -142,6 +142,16 @@ pub trait ResultFolder<'ast, T: Field>: Sized { fold_types(self, tys) } + fn fold_if_else_expression< + E: Expr<'ast, T> + PartialEq + IfElse<'ast, T> + ResultFold<'ast, T>, + >( + &mut self, + ty: &E::Ty, + e: IfElseExpression<'ast, T, E>, + ) -> Result, Self::Error> { + fold_if_else_expression(self, ty, e) + } + fn fold_block_expression>( &mut self, block: BlockExpression<'ast, T, E>, @@ -415,13 +425,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => { - ArrayExpressionInner::IfElse( - box f.fold_boolean_expression(condition)?, - box f.fold_array_expression(consequence)?, - box f.fold_array_expression(alternative)?, - ) - } + ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? { + IfElseOrExpression::IfElse(c) => ArrayExpressionInner::IfElse(c), + IfElseOrExpression::Expression(u) => u, + }, ArrayExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? { MemberOrExpression::Member(m) => ArrayExpressionInner::Member(m), MemberOrExpression::Expression(u) => u, @@ -469,13 +476,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - StructExpressionInner::IfElse(box condition, box consequence, box alternative) => { - StructExpressionInner::IfElse( - box f.fold_boolean_expression(condition)?, - box f.fold_struct_expression(consequence)?, - box f.fold_struct_expression(alternative)?, - ) - } + StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? { + IfElseOrExpression::IfElse(c) => StructExpressionInner::IfElse(c), + IfElseOrExpression::Expression(u) => u, + }, StructExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? { MemberOrExpression::Member(m) => StructExpressionInner::Member(m), MemberOrExpression::Expression(u) => u, @@ -535,11 +539,11 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( FieldElementExpression::Pos(box e) } - FieldElementExpression::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond)?; - let cons = f.fold_field_expression(cons)?; - let alt = f.fold_field_expression(alt)?; - FieldElementExpression::IfElse(box cond, box cons, box alt) + FieldElementExpression::IfElse(c) => { + match f.fold_if_else_expression(&Type::FieldElement, c)? { + IfElseOrExpression::IfElse(c) => FieldElementExpression::IfElse(c), + IfElseOrExpression::Expression(u) => u, + } } FieldElementExpression::FunctionCall(function_call) => { match f.fold_function_call_expression(&Type::FieldElement, function_call)? { @@ -589,6 +593,27 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo }) } +pub fn fold_if_else_expression< + 'ast, + T: Field, + E: Expr<'ast, T> + + IfElse<'ast, T> + + PartialEq + + ResultFold<'ast, T> + + From>, + F: ResultFolder<'ast, T>, +>( + f: &mut F, + _: &E::Ty, + e: IfElseExpression<'ast, T, E>, +) -> Result, F::Error> { + Ok(IfElseOrExpression::IfElse(IfElseExpression::new( + f.fold_boolean_expression(*e.condition)?, + e.consequence.fold(f)?, + e.alternative.fold(f)?, + ))) +} + pub fn fold_member_expression< 'ast, T: Field, @@ -739,12 +764,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>( FunctionCallOrExpression::Expression(u) => u, } } - BooleanExpression::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond)?; - let cons = f.fold_boolean_expression(cons)?; - let alt = f.fold_boolean_expression(alt)?; - BooleanExpression::IfElse(box cond, box cons, box alt) - } + BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c)? { + IfElseOrExpression::IfElse(c) => BooleanExpression::IfElse(c), + IfElseOrExpression::Expression(u) => u, + }, BooleanExpression::Select(select) => { match f.fold_select_expression(&Type::Boolean, select)? { SelectOrExpression::Select(s) => BooleanExpression::Select(s), @@ -869,12 +892,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>( SelectOrExpression::Select(s) => UExpressionInner::Select(s), SelectOrExpression::Expression(u) => u, }, - UExpressionInner::IfElse(box cond, box cons, box alt) => { - let cond = f.fold_boolean_expression(cond)?; - let cons = f.fold_uint_expression(cons)?; - let alt = f.fold_uint_expression(alt)?; - UExpressionInner::IfElse(box cond, box cons, box alt) - } + UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c)? { + IfElseOrExpression::IfElse(c) => UExpressionInner::IfElse(c), + IfElseOrExpression::Expression(u) => u, + }, UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m)? { MemberOrExpression::Member(m) => UExpressionInner::Member(m), MemberOrExpression::Expression(u) => u, diff --git a/zokrates_core/src/typed_absy/uint.rs b/zokrates_core/src/typed_absy/uint.rs index d576114d..4620dd7c 100644 --- a/zokrates_core/src/typed_absy/uint.rs +++ b/zokrates_core/src/typed_absy/uint.rs @@ -193,11 +193,7 @@ pub enum UExpressionInner<'ast, T> { FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>), LeftShift(Box>, Box>), RightShift(Box>, Box>), - IfElse( - Box>, - Box>, - Box>, - ), + IfElse(IfElseExpression<'ast, T, UExpression<'ast, T>>), Member(MemberExpression<'ast, T, UExpression<'ast, T>>), Select(SelectExpression<'ast, T, UExpression<'ast, T>>), }