diff --git a/zokrates_analysis/src/zir_propagation.rs b/zokrates_analysis/src/zir_propagation.rs index 8582b33b..1a026b17 100644 --- a/zokrates_analysis/src/zir_propagation.rs +++ b/zokrates_analysis/src/zir_propagation.rs @@ -1,6 +1,8 @@ +use num::traits::Pow; +use num_bigint::BigUint; use std::collections::HashMap; use std::fmt; -use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr}; +use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub}; use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::{ result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id, @@ -312,15 +314,30 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { .. }, ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } ( FieldElementExpression::Number(n), UExpression { inner: UExpressionInner::Value(by), .. }, - ) => Ok(FieldElementExpression::Number( - T::try_from(n.to_biguint().shl(by as usize)).unwrap(), - )), + ) => { + let two = BigUint::from(2usize); + let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); + + Ok(FieldElementExpression::Number( + T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask)) + .unwrap(), + )) + } (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), } } @@ -335,6 +352,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> { .. }, ) if by == 0 => Ok(e), + ( + _, + UExpression { + inner: UExpressionInner::Value(by), + .. + }, + ) if by as usize >= T::get_required_bits() => { + Ok(FieldElementExpression::Number(T::from(0))) + } ( FieldElementExpression::Number(n), UExpression { @@ -948,6 +974,115 @@ mod tests { ); } + #[test] + fn left_shift() { + let mut propagator = ZirPropagator::::default(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(8))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("14474011154664524427946373126085988481658748083205070504932198000989141204992").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 3) as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::try_from_dec_str("10855508365998393320959779844564491361244061062403802878699148500741855903744").unwrap())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::LeftShift( + box FieldElementExpression::Number(Bn128Field::from(1)), + box UExpressionInner::Value((Bn128Field::get_required_bits()) as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + + #[test] + fn right_shift() { + let mut propagator = ZirPropagator::::default(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(0).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::identifier("a".into()), + box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(3)), + box UExpressionInner::Value(1 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(2 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::from(2)), + box UExpressionInner::Value(4 as u128).annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box UExpressionInner::Value((Bn128Field::get_required_bits() - 1) as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(1))) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::RightShift( + box FieldElementExpression::Number(Bn128Field::max_value()), + box UExpressionInner::Value(Bn128Field::get_required_bits() as u128) + .annotate(UBitwidth::B32), + )), + Ok(FieldElementExpression::Number(Bn128Field::from(0))) + ); + } + #[test] fn if_else() { let mut propagator = ZirPropagator::default();