From e193546d8d40204b5a47abf6412bc0929fef8099 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 7 Apr 2021 18:56:55 +0200 Subject: [PATCH] check shifts are constant earlier, simplify zir --- .../compile_errors/variable_shift.zok | 2 + zokrates_core/src/flatten/mod.rs | 26 +----- .../static_analysis/flatten_complex_types.rs | 16 +++- zokrates_core/src/static_analysis/mod.rs | 4 + .../src/static_analysis/shift_checker.rs | 55 ++++++++++++ .../src/static_analysis/uint_optimizer.rs | 84 ++++++------------- zokrates_core/src/zir/folder.rs | 10 +-- zokrates_core/src/zir/uint.rs | 14 ++-- 8 files changed, 111 insertions(+), 100 deletions(-) create mode 100644 zokrates_cli/examples/compile_errors/variable_shift.zok create mode 100644 zokrates_core/src/static_analysis/shift_checker.rs diff --git a/zokrates_cli/examples/compile_errors/variable_shift.zok b/zokrates_cli/examples/compile_errors/variable_shift.zok new file mode 100644 index 00000000..ea7db352 --- /dev/null +++ b/zokrates_cli/examples/compile_errors/variable_shift.zok @@ -0,0 +1,2 @@ +def main(u32 a, u32 b) -> u32: + return a >> b \ No newline at end of file diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index 61c287aa..e3e09238 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -1263,18 +1263,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { box FlatExpression::Sub(box new_left, box new_right), )) } - UExpressionInner::LeftShift(box e, box by) => { - assert_eq!(by.bitwidth(), UBitwidth::B32); - - let by = match by.into_inner() { - UExpressionInner::Value(n) => n, - by => unimplemented!( - "Variable shifts are unimplemented, found {} << {}", - e, - by.annotate(UBitwidth::B32) - ), - }; - + UExpressionInner::LeftShift(box e, by) => { let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); @@ -1292,18 +1281,7 @@ impl<'ast, T: Field> Flattener<'ast, T> { .collect::>(), ) } - UExpressionInner::RightShift(box e, box by) => { - assert_eq!(by.bitwidth(), UBitwidth::B32); - - let by = match by.into_inner() { - UExpressionInner::Value(n) => n, - by => unimplemented!( - "Variable shifts are unimplemented, found {} >> {}", - e, - by.annotate(UBitwidth::B32) - ), - }; - + UExpressionInner::RightShift(box e, by) => { let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); diff --git a/zokrates_core/src/static_analysis/flatten_complex_types.rs b/zokrates_core/src/static_analysis/flatten_complex_types.rs index ca97b959..9a902d81 100644 --- a/zokrates_core/src/static_analysis/flatten_complex_types.rs +++ b/zokrates_core/src/static_analysis/flatten_complex_types.rs @@ -855,15 +855,23 @@ pub fn fold_uint_expression_inner<'ast, T: Field>( } typed_absy::UExpressionInner::LeftShift(box e, box by) => { let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - zir::UExpressionInner::LeftShift(box e, box by) + let by = match by.as_inner() { + typed_absy::UExpressionInner::Value(by) => by, + _ => unreachable!("static analysis should have made sure that this is constant"), + }; + + zir::UExpressionInner::LeftShift(box e, *by as u32) } typed_absy::UExpressionInner::RightShift(box e, box by) => { let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - zir::UExpressionInner::RightShift(box e, box by) + let by = match by.as_inner() { + typed_absy::UExpressionInner::Value(by) => by, + _ => unreachable!("static analysis should have made sure that this is constant"), + }; + + zir::UExpressionInner::RightShift(box e, *by as u32) } typed_absy::UExpressionInner::Not(box e) => { let e = f.fold_uint_expression(e); diff --git a/zokrates_core/src/static_analysis/mod.rs b/zokrates_core/src/static_analysis/mod.rs index 7768ba56..e30c5e54 100644 --- a/zokrates_core/src/static_analysis/mod.rs +++ b/zokrates_core/src/static_analysis/mod.rs @@ -10,6 +10,7 @@ mod flatten_complex_types; mod propagation; mod redefinition; mod reducer; +mod shift_checker; mod uint_optimizer; mod unconstrained_vars; mod variable_read_remover; @@ -20,6 +21,7 @@ use self::flatten_complex_types::Flattener; use self::propagation::Propagator; use self::redefinition::RedefinitionOptimizer; use self::reducer::reduce_program; +use self::shift_checker::ShiftChecker; use self::uint_optimizer::UintOptimizer; use self::unconstrained_vars::UnconstrainedVariableDetector; use self::variable_read_remover::VariableReadRemover; @@ -85,6 +87,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> { let r = VariableReadRemover::apply(r); // check array accesses are in bounds let r = BoundsChecker::check(r).map_err(Error::from)?; + // detect non constant shifts + let r = ShiftChecker::check(r).map_err(Error::from)?; // convert to zir, removing complex types let zir = Flattener::flatten(r); // optimize uint expressions diff --git a/zokrates_core/src/static_analysis/shift_checker.rs b/zokrates_core/src/static_analysis/shift_checker.rs new file mode 100644 index 00000000..7e44ea52 --- /dev/null +++ b/zokrates_core/src/static_analysis/shift_checker.rs @@ -0,0 +1,55 @@ +use crate::typed_absy::TypedProgram; +use crate::typed_absy::{ + result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth, + UExpressionInner, +}; +use zokrates_field::Field; +pub struct ShiftChecker; + +impl ShiftChecker { + pub fn check(p: TypedProgram) -> Result, Error> { + ShiftChecker.fold_program(p) + } +} + +pub type Error = String; + +impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker { + type Error = Error; + + fn fold_uint_expression_inner( + &mut self, + bitwidth: UBitwidth, + e: UExpressionInner<'ast, T>, + ) -> Result, Error> { + match e { + UExpressionInner::LeftShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} << {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + UExpressionInner::RightShift(box e, box by) => { + let e = self.fold_uint_expression(e)?; + let by = self.fold_uint_expression(by)?; + + match by.as_inner() { + UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)), + by => Err(format!( + "Cannot shift by a variable value, found `{} >> {}`", + e, + by.clone().annotate(UBitwidth::B32) + )), + } + } + e => fold_uint_expression_inner(self, bitwidth, e), + } + } +} diff --git a/zokrates_core/src/static_analysis/uint_optimizer.rs b/zokrates_core/src/static_analysis/uint_optimizer.rs index 10296a20..a257a489 100644 --- a/zokrates_core/src/static_analysis/uint_optimizer.rs +++ b/zokrates_core/src/static_analysis/uint_optimizer.rs @@ -314,19 +314,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .annotate(range) .with_max(range_max) } - LeftShift(box e, box by) => { + LeftShift(box e, by) => { // reduce both terms let e = self.fold_uint_expression(e); - let by = self.fold_uint_expression(by); - let by_max: u128 = by - .metadata - .clone() - .unwrap() - .max - .to_dec_string() - .parse() - .unwrap(); let e_max: u128 = e .metadata .clone() @@ -336,20 +327,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .parse() .unwrap(); - let max = T::from((e_max << by_max) & (2_u128.pow(range as u32) - 1)); + let max = T::from((e_max << by) & (2_u128.pow(range as u32) - 1)); - UExpression::left_shift(force_reduce(e), force_reduce(by)).with_max(max) + UExpression::left_shift(force_reduce(e), by).with_max(max) } - RightShift(box e, box by) => { + RightShift(box e, by) => { // reduce both terms let e = self.fold_uint_expression(e); - let by = self.fold_uint_expression(by); - - // if we don't know the amount by which we shift, the most conservative case (which leads to the biggest value) is 0 - let by_u = match by.as_inner() { - UExpressionInner::Value(by) => *by, - _ => 0, - }; let e_max: u128 = e .metadata @@ -360,11 +344,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> { .parse() .unwrap(); - let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by_u; + let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by; let max = T::from(max); - UExpression::right_shift(force_reduce(e), force_reduce(by)).with_max(max) + UExpression::right_shift(force_reduce(e), by).with_max(max) } IfElse(box condition, box consequence, box alternative) => { let condition = self.fold_boolean_expression(condition); @@ -694,30 +678,14 @@ mod tests { #[test] fn right_shift() { - // left argument in range, we reduce (no effect) and the max is the original max, as we could be shifting by 0 - uint_test!(0xff_u32, true, 2, true, right_shift, 0xff_u32); - uint_test!(2, true, 2, true, right_shift, 2_u32); - - // left argument out of range, we reduce and the max is the type max, shifted - uint_test!( - 0xffffffffffff_u128, - true, - 2, - true, - right_shift, - 0xffffffff_u32 - ); - fn right_shift_test(e_max: u128, by: u32, output_max: u32) { let left = e_with_max(e_max); - let right = UExpressionInner::Value(by as u128) - .annotate(crate::zir::types::UBitwidth::B32) - .with_max(by); + let right = by; let left_expected = force_reduce(left.clone()); - let right_expected = force_reduce(right.clone()); + let right_expected = right; assert_eq!( UintOptimizer::new() @@ -733,25 +701,25 @@ mod tests { #[test] fn left_shift() { - uint_test!(0xff_u32, true, 2, true, left_shift, 0xff_u32 << 2); - uint_test!( - 0xffffffff_u32, - true, - 2, - true, - left_shift, - 0xffffffff_u32 << 2 - ); + fn left_shift_test(e_max: u128, by: u32, output_max: u32) { + let left = e_with_max(e_max); - // left argument out of range, we reduce and the max is the type max, shifted - uint_test!( - 0xffffffffffff_u128, - true, - 2, - true, - left_shift, - 0xffffffff_u32 << 2 - ) + let right = by; + + let left_expected = force_reduce(left.clone()); + + let right_expected = right; + + assert_eq!( + UintOptimizer::new() + .fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())), + UExpression::left_shift(left_expected, right_expected).with_max(output_max) + ); + } + + left_shift_test(0xff_u128, 2, 0xff << 2); + left_shift_test(2, 2, 2 << 2); + left_shift_test(0xffffffffffff_u128, 2, 0xffffffff << 2); } #[test] diff --git a/zokrates_core/src/zir/folder.rs b/zokrates_core/src/zir/folder.rs index f82fbee1..93710896 100644 --- a/zokrates_core/src/zir/folder.rs +++ b/zokrates_core/src/zir/folder.rs @@ -308,17 +308,15 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>( UExpressionInner::Or(box left, box right) } - UExpressionInner::LeftShift(box e, box by) => { + UExpressionInner::LeftShift(box e, by) => { let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - UExpressionInner::LeftShift(box e, box by) + UExpressionInner::LeftShift(box e, by) } - UExpressionInner::RightShift(box e, box by) => { + UExpressionInner::RightShift(box e, by) => { let e = f.fold_uint_expression(e); - let by = f.fold_uint_expression(by); - UExpressionInner::RightShift(box e, box by) + UExpressionInner::RightShift(box e, by) } UExpressionInner::Not(box e) => { let e = f.fold_uint_expression(e); diff --git a/zokrates_core/src/zir/uint.rs b/zokrates_core/src/zir/uint.rs index 660e4a54..006da33f 100644 --- a/zokrates_core/src/zir/uint.rs +++ b/zokrates_core/src/zir/uint.rs @@ -57,16 +57,14 @@ impl<'ast, T: Field> UExpression<'ast, T> { UExpressionInner::And(box self, box other).annotate(bitwidth) } - pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { + pub fn left_shift(self, by: u32) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - assert_eq!(by.bitwidth(), UBitwidth::B32); - UExpressionInner::LeftShift(box self, box by).annotate(bitwidth) + UExpressionInner::LeftShift(box self, by).annotate(bitwidth) } - pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { + pub fn right_shift(self, by: u32) -> UExpression<'ast, T> { let bitwidth = self.bitwidth; - assert_eq!(by.bitwidth(), UBitwidth::B32); - UExpressionInner::RightShift(box self, box by).annotate(bitwidth) + UExpressionInner::RightShift(box self, by).annotate(bitwidth) } } @@ -170,8 +168,8 @@ pub enum UExpressionInner<'ast, T> { Xor(Box>, Box>), And(Box>, Box>), Or(Box>, Box>), - LeftShift(Box>, Box>), - RightShift(Box>, Box>), + LeftShift(Box>, u32), + RightShift(Box>, u32), Not(Box>), IfElse( Box>,