Merge pull request #799 from Zokrates/refactor-constant-shift-checks
Fail gracefully on variable shifts
This commit is contained in:
commit
b4f02e7db5
8 changed files with 111 additions and 100 deletions
2
zokrates_cli/examples/compile_errors/variable_shift.zok
Normal file
2
zokrates_cli/examples/compile_errors/variable_shift.zok
Normal file
|
@ -0,0 +1,2 @@
|
|||
def main(u32 a, u32 b) -> u32:
|
||||
return a >> b
|
|
@ -1263,18 +1263,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box FlatExpression::Sub(box new_left, box new_right),
|
||||
))
|
||||
}
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
||||
|
||||
let by = match by.into_inner() {
|
||||
UExpressionInner::Value(n) => n,
|
||||
by => unimplemented!(
|
||||
"Variable shifts are unimplemented, found {} << {}",
|
||||
e,
|
||||
by.annotate(UBitwidth::B32)
|
||||
),
|
||||
};
|
||||
|
||||
UExpressionInner::LeftShift(box e, by) => {
|
||||
let e = self.flatten_uint_expression(statements_flattened, e);
|
||||
|
||||
let e_bits = e.bits.unwrap();
|
||||
|
@ -1292,18 +1281,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
||||
|
||||
let by = match by.into_inner() {
|
||||
UExpressionInner::Value(n) => n,
|
||||
by => unimplemented!(
|
||||
"Variable shifts are unimplemented, found {} >> {}",
|
||||
e,
|
||||
by.annotate(UBitwidth::B32)
|
||||
),
|
||||
};
|
||||
|
||||
UExpressionInner::RightShift(box e, by) => {
|
||||
let e = self.flatten_uint_expression(statements_flattened, e);
|
||||
|
||||
let e_bits = e.bits.unwrap();
|
||||
|
|
|
@ -855,15 +855,23 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
|||
}
|
||||
typed_absy::UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
zir::UExpressionInner::LeftShift(box e, box by)
|
||||
let by = match by.as_inner() {
|
||||
typed_absy::UExpressionInner::Value(by) => by,
|
||||
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||
};
|
||||
|
||||
zir::UExpressionInner::LeftShift(box e, *by as u32)
|
||||
}
|
||||
typed_absy::UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
zir::UExpressionInner::RightShift(box e, box by)
|
||||
let by = match by.as_inner() {
|
||||
typed_absy::UExpressionInner::Value(by) => by,
|
||||
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||
};
|
||||
|
||||
zir::UExpressionInner::RightShift(box e, *by as u32)
|
||||
}
|
||||
typed_absy::UExpressionInner::Not(box e) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
|
|
|
@ -10,6 +10,7 @@ mod flatten_complex_types;
|
|||
mod propagation;
|
||||
mod redefinition;
|
||||
mod reducer;
|
||||
mod shift_checker;
|
||||
mod uint_optimizer;
|
||||
mod unconstrained_vars;
|
||||
mod variable_read_remover;
|
||||
|
@ -20,6 +21,7 @@ use self::flatten_complex_types::Flattener;
|
|||
use self::propagation::Propagator;
|
||||
use self::redefinition::RedefinitionOptimizer;
|
||||
use self::reducer::reduce_program;
|
||||
use self::shift_checker::ShiftChecker;
|
||||
use self::uint_optimizer::UintOptimizer;
|
||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||
use self::variable_read_remover::VariableReadRemover;
|
||||
|
@ -85,6 +87,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
|||
let r = VariableReadRemover::apply(r);
|
||||
// check array accesses are in bounds
|
||||
let r = BoundsChecker::check(r).map_err(Error::from)?;
|
||||
// detect non constant shifts
|
||||
let r = ShiftChecker::check(r).map_err(Error::from)?;
|
||||
// convert to zir, removing complex types
|
||||
let zir = Flattener::flatten(r);
|
||||
// optimize uint expressions
|
||||
|
|
55
zokrates_core/src/static_analysis/shift_checker.rs
Normal file
55
zokrates_core/src/static_analysis/shift_checker.rs
Normal file
|
@ -0,0 +1,55 @@
|
|||
use crate::typed_absy::TypedProgram;
|
||||
use crate::typed_absy::{
|
||||
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
|
||||
UExpressionInner,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
pub struct ShiftChecker;
|
||||
|
||||
impl ShiftChecker {
|
||||
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||
ShiftChecker.fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Error = String;
|
||||
|
||||
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
|
||||
type Error = Error;
|
||||
|
||||
fn fold_uint_expression_inner(
|
||||
&mut self,
|
||||
bitwidth: UBitwidth,
|
||||
e: UExpressionInner<'ast, T>,
|
||||
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||
match e {
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
let e = self.fold_uint_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||
by => Err(format!(
|
||||
"Cannot shift by a variable value, found `{} >> {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -314,19 +314,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.annotate(range)
|
||||
.with_max(range_max)
|
||||
}
|
||||
LeftShift(box e, box by) => {
|
||||
LeftShift(box e, by) => {
|
||||
// reduce both terms
|
||||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_uint_expression(by);
|
||||
|
||||
let by_max: u128 = by
|
||||
.metadata
|
||||
.clone()
|
||||
.unwrap()
|
||||
.max
|
||||
.to_dec_string()
|
||||
.parse()
|
||||
.unwrap();
|
||||
let e_max: u128 = e
|
||||
.metadata
|
||||
.clone()
|
||||
|
@ -336,20 +327,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.parse()
|
||||
.unwrap();
|
||||
|
||||
let max = T::from((e_max << by_max) & (2_u128.pow(range as u32) - 1));
|
||||
let max = T::from((e_max << by) & (2_u128.pow(range as u32) - 1));
|
||||
|
||||
UExpression::left_shift(force_reduce(e), force_reduce(by)).with_max(max)
|
||||
UExpression::left_shift(force_reduce(e), by).with_max(max)
|
||||
}
|
||||
RightShift(box e, box by) => {
|
||||
RightShift(box e, by) => {
|
||||
// reduce both terms
|
||||
let e = self.fold_uint_expression(e);
|
||||
let by = self.fold_uint_expression(by);
|
||||
|
||||
// if we don't know the amount by which we shift, the most conservative case (which leads to the biggest value) is 0
|
||||
let by_u = match by.as_inner() {
|
||||
UExpressionInner::Value(by) => *by,
|
||||
_ => 0,
|
||||
};
|
||||
|
||||
let e_max: u128 = e
|
||||
.metadata
|
||||
|
@ -360,11 +344,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
.parse()
|
||||
.unwrap();
|
||||
|
||||
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by_u;
|
||||
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by;
|
||||
|
||||
let max = T::from(max);
|
||||
|
||||
UExpression::right_shift(force_reduce(e), force_reduce(by)).with_max(max)
|
||||
UExpression::right_shift(force_reduce(e), by).with_max(max)
|
||||
}
|
||||
IfElse(box condition, box consequence, box alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition);
|
||||
|
@ -694,30 +678,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn right_shift() {
|
||||
// left argument in range, we reduce (no effect) and the max is the original max, as we could be shifting by 0
|
||||
uint_test!(0xff_u32, true, 2, true, right_shift, 0xff_u32);
|
||||
uint_test!(2, true, 2, true, right_shift, 2_u32);
|
||||
|
||||
// left argument out of range, we reduce and the max is the type max, shifted
|
||||
uint_test!(
|
||||
0xffffffffffff_u128,
|
||||
true,
|
||||
2,
|
||||
true,
|
||||
right_shift,
|
||||
0xffffffff_u32
|
||||
);
|
||||
|
||||
fn right_shift_test(e_max: u128, by: u32, output_max: u32) {
|
||||
let left = e_with_max(e_max);
|
||||
|
||||
let right = UExpressionInner::Value(by as u128)
|
||||
.annotate(crate::zir::types::UBitwidth::B32)
|
||||
.with_max(by);
|
||||
let right = by;
|
||||
|
||||
let left_expected = force_reduce(left.clone());
|
||||
|
||||
let right_expected = force_reduce(right.clone());
|
||||
let right_expected = right;
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
|
@ -733,25 +701,25 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn left_shift() {
|
||||
uint_test!(0xff_u32, true, 2, true, left_shift, 0xff_u32 << 2);
|
||||
uint_test!(
|
||||
0xffffffff_u32,
|
||||
true,
|
||||
2,
|
||||
true,
|
||||
left_shift,
|
||||
0xffffffff_u32 << 2
|
||||
);
|
||||
fn left_shift_test(e_max: u128, by: u32, output_max: u32) {
|
||||
let left = e_with_max(e_max);
|
||||
|
||||
// left argument out of range, we reduce and the max is the type max, shifted
|
||||
uint_test!(
|
||||
0xffffffffffff_u128,
|
||||
true,
|
||||
2,
|
||||
true,
|
||||
left_shift,
|
||||
0xffffffff_u32 << 2
|
||||
)
|
||||
let right = by;
|
||||
|
||||
let left_expected = force_reduce(left.clone());
|
||||
|
||||
let right_expected = right;
|
||||
|
||||
assert_eq!(
|
||||
UintOptimizer::new()
|
||||
.fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
|
||||
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
|
||||
);
|
||||
}
|
||||
|
||||
left_shift_test(0xff_u128, 2, 0xff << 2);
|
||||
left_shift_test(2, 2, 2 << 2);
|
||||
left_shift_test(0xffffffffffff_u128, 2, 0xffffffff << 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -308,17 +308,15 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
UExpressionInner::Or(box left, box right)
|
||||
}
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
UExpressionInner::LeftShift(box e, by) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
UExpressionInner::LeftShift(box e, box by)
|
||||
UExpressionInner::LeftShift(box e, by)
|
||||
}
|
||||
UExpressionInner::RightShift(box e, box by) => {
|
||||
UExpressionInner::RightShift(box e, by) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
UExpressionInner::RightShift(box e, box by)
|
||||
UExpressionInner::RightShift(box e, by)
|
||||
}
|
||||
UExpressionInner::Not(box e) => {
|
||||
let e = f.fold_uint_expression(e);
|
||||
|
|
|
@ -57,16 +57,14 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
|||
UExpressionInner::And(box self, box other).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
pub fn left_shift(self, by: u32) -> UExpression<'ast, T> {
|
||||
let bitwidth = self.bitwidth;
|
||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
||||
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
|
||||
UExpressionInner::LeftShift(box self, by).annotate(bitwidth)
|
||||
}
|
||||
|
||||
pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
||||
pub fn right_shift(self, by: u32) -> UExpression<'ast, T> {
|
||||
let bitwidth = self.bitwidth;
|
||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
||||
UExpressionInner::RightShift(box self, box by).annotate(bitwidth)
|
||||
UExpressionInner::RightShift(box self, by).annotate(bitwidth)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,8 +168,8 @@ pub enum UExpressionInner<'ast, T> {
|
|||
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
LeftShift(Box<UExpression<'ast, T>>, u32),
|
||||
RightShift(Box<UExpression<'ast, T>>, u32),
|
||||
Not(Box<UExpression<'ast, T>>),
|
||||
IfElse(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
|
|
Loading…
Reference in a new issue