1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

Merge pull request #799 from Zokrates/refactor-constant-shift-checks

Fail gracefully on variable shifts
This commit is contained in:
Thibaut Schaeffer 2021-04-19 20:51:44 +02:00 committed by GitHub
commit b4f02e7db5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 111 additions and 100 deletions

View file

@ -0,0 +1,2 @@
def main(u32 a, u32 b) -> u32:
return a >> b

View file

@ -1263,18 +1263,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box FlatExpression::Sub(box new_left, box new_right), box FlatExpression::Sub(box new_left, box new_right),
)) ))
} }
UExpressionInner::LeftShift(box e, box by) => { UExpressionInner::LeftShift(box e, by) => {
assert_eq!(by.bitwidth(), UBitwidth::B32);
let by = match by.into_inner() {
UExpressionInner::Value(n) => n,
by => unimplemented!(
"Variable shifts are unimplemented, found {} << {}",
e,
by.annotate(UBitwidth::B32)
),
};
let e = self.flatten_uint_expression(statements_flattened, e); let e = self.flatten_uint_expression(statements_flattened, e);
let e_bits = e.bits.unwrap(); let e_bits = e.bits.unwrap();
@ -1292,18 +1281,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
) )
} }
UExpressionInner::RightShift(box e, box by) => { UExpressionInner::RightShift(box e, by) => {
assert_eq!(by.bitwidth(), UBitwidth::B32);
let by = match by.into_inner() {
UExpressionInner::Value(n) => n,
by => unimplemented!(
"Variable shifts are unimplemented, found {} >> {}",
e,
by.annotate(UBitwidth::B32)
),
};
let e = self.flatten_uint_expression(statements_flattened, e); let e = self.flatten_uint_expression(statements_flattened, e);
let e_bits = e.bits.unwrap(); let e_bits = e.bits.unwrap();

View file

@ -855,15 +855,23 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
} }
typed_absy::UExpressionInner::LeftShift(box e, box by) => { typed_absy::UExpressionInner::LeftShift(box e, box by) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
zir::UExpressionInner::LeftShift(box e, box by) let by = match by.as_inner() {
typed_absy::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::UExpressionInner::LeftShift(box e, *by as u32)
} }
typed_absy::UExpressionInner::RightShift(box e, box by) => { typed_absy::UExpressionInner::RightShift(box e, box by) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
zir::UExpressionInner::RightShift(box e, box by) let by = match by.as_inner() {
typed_absy::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::UExpressionInner::RightShift(box e, *by as u32)
} }
typed_absy::UExpressionInner::Not(box e) => { typed_absy::UExpressionInner::Not(box e) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);

View file

@ -10,6 +10,7 @@ mod flatten_complex_types;
mod propagation; mod propagation;
mod redefinition; mod redefinition;
mod reducer; mod reducer;
mod shift_checker;
mod uint_optimizer; mod uint_optimizer;
mod unconstrained_vars; mod unconstrained_vars;
mod variable_read_remover; mod variable_read_remover;
@ -20,6 +21,7 @@ use self::flatten_complex_types::Flattener;
use self::propagation::Propagator; use self::propagation::Propagator;
use self::redefinition::RedefinitionOptimizer; use self::redefinition::RedefinitionOptimizer;
use self::reducer::reduce_program; use self::reducer::reduce_program;
use self::shift_checker::ShiftChecker;
use self::uint_optimizer::UintOptimizer; use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector; use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_read_remover::VariableReadRemover; use self::variable_read_remover::VariableReadRemover;
@ -85,6 +87,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = VariableReadRemover::apply(r); let r = VariableReadRemover::apply(r);
// check array accesses are in bounds // check array accesses are in bounds
let r = BoundsChecker::check(r).map_err(Error::from)?; let r = BoundsChecker::check(r).map_err(Error::from)?;
// detect non constant shifts
let r = ShiftChecker::check(r).map_err(Error::from)?;
// convert to zir, removing complex types // convert to zir, removing complex types
let zir = Flattener::flatten(r); let zir = Flattener::flatten(r);
// optimize uint expressions // optimize uint expressions

View file

@ -0,0 +1,55 @@
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
UExpressionInner,
};
use zokrates_field::Field;
pub struct ShiftChecker;
impl ShiftChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
ShiftChecker.fold_program(p)
}
}
pub type Error = String;
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
type Error = Error;
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Error> {
match e {
UExpressionInner::LeftShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} >> {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
}

View file

@ -314,19 +314,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.annotate(range) .annotate(range)
.with_max(range_max) .with_max(range_max)
} }
LeftShift(box e, box by) => { LeftShift(box e, by) => {
// reduce both terms // reduce both terms
let e = self.fold_uint_expression(e); let e = self.fold_uint_expression(e);
let by = self.fold_uint_expression(by);
let by_max: u128 = by
.metadata
.clone()
.unwrap()
.max
.to_dec_string()
.parse()
.unwrap();
let e_max: u128 = e let e_max: u128 = e
.metadata .metadata
.clone() .clone()
@ -336,20 +327,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.parse() .parse()
.unwrap(); .unwrap();
let max = T::from((e_max << by_max) & (2_u128.pow(range as u32) - 1)); let max = T::from((e_max << by) & (2_u128.pow(range as u32) - 1));
UExpression::left_shift(force_reduce(e), force_reduce(by)).with_max(max) UExpression::left_shift(force_reduce(e), by).with_max(max)
} }
RightShift(box e, box by) => { RightShift(box e, by) => {
// reduce both terms // reduce both terms
let e = self.fold_uint_expression(e); let e = self.fold_uint_expression(e);
let by = self.fold_uint_expression(by);
// if we don't know the amount by which we shift, the most conservative case (which leads to the biggest value) is 0
let by_u = match by.as_inner() {
UExpressionInner::Value(by) => *by,
_ => 0,
};
let e_max: u128 = e let e_max: u128 = e
.metadata .metadata
@ -360,11 +344,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.parse() .parse()
.unwrap(); .unwrap();
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by_u; let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by;
let max = T::from(max); let max = T::from(max);
UExpression::right_shift(force_reduce(e), force_reduce(by)).with_max(max) UExpression::right_shift(force_reduce(e), by).with_max(max)
} }
IfElse(box condition, box consequence, box alternative) => { IfElse(box condition, box consequence, box alternative) => {
let condition = self.fold_boolean_expression(condition); let condition = self.fold_boolean_expression(condition);
@ -694,30 +678,14 @@ mod tests {
#[test] #[test]
fn right_shift() { fn right_shift() {
// left argument in range, we reduce (no effect) and the max is the original max, as we could be shifting by 0
uint_test!(0xff_u32, true, 2, true, right_shift, 0xff_u32);
uint_test!(2, true, 2, true, right_shift, 2_u32);
// left argument out of range, we reduce and the max is the type max, shifted
uint_test!(
0xffffffffffff_u128,
true,
2,
true,
right_shift,
0xffffffff_u32
);
fn right_shift_test(e_max: u128, by: u32, output_max: u32) { fn right_shift_test(e_max: u128, by: u32, output_max: u32) {
let left = e_with_max(e_max); let left = e_with_max(e_max);
let right = UExpressionInner::Value(by as u128) let right = by;
.annotate(crate::zir::types::UBitwidth::B32)
.with_max(by);
let left_expected = force_reduce(left.clone()); let left_expected = force_reduce(left.clone());
let right_expected = force_reduce(right.clone()); let right_expected = right;
assert_eq!( assert_eq!(
UintOptimizer::new() UintOptimizer::new()
@ -733,25 +701,25 @@ mod tests {
#[test] #[test]
fn left_shift() { fn left_shift() {
uint_test!(0xff_u32, true, 2, true, left_shift, 0xff_u32 << 2); fn left_shift_test(e_max: u128, by: u32, output_max: u32) {
uint_test!( let left = e_with_max(e_max);
0xffffffff_u32,
true,
2,
true,
left_shift,
0xffffffff_u32 << 2
);
// left argument out of range, we reduce and the max is the type max, shifted let right = by;
uint_test!(
0xffffffffffff_u128, let left_expected = force_reduce(left.clone());
true,
2, let right_expected = right;
true,
left_shift, assert_eq!(
0xffffffff_u32 << 2 UintOptimizer::new()
) .fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
);
}
left_shift_test(0xff_u128, 2, 0xff << 2);
left_shift_test(2, 2, 2 << 2);
left_shift_test(0xffffffffffff_u128, 2, 0xffffffff << 2);
} }
#[test] #[test]

View file

@ -308,17 +308,15 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Or(box left, box right) UExpressionInner::Or(box left, box right)
} }
UExpressionInner::LeftShift(box e, box by) => { UExpressionInner::LeftShift(box e, by) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
UExpressionInner::LeftShift(box e, box by) UExpressionInner::LeftShift(box e, by)
} }
UExpressionInner::RightShift(box e, box by) => { UExpressionInner::RightShift(box e, by) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
UExpressionInner::RightShift(box e, box by) UExpressionInner::RightShift(box e, by)
} }
UExpressionInner::Not(box e) => { UExpressionInner::Not(box e) => {
let e = f.fold_uint_expression(e); let e = f.fold_uint_expression(e);

View file

@ -57,16 +57,14 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::And(box self, box other).annotate(bitwidth) UExpressionInner::And(box self, box other).annotate(bitwidth)
} }
pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { pub fn left_shift(self, by: u32) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth; let bitwidth = self.bitwidth;
assert_eq!(by.bitwidth(), UBitwidth::B32); UExpressionInner::LeftShift(box self, by).annotate(bitwidth)
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
} }
pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> { pub fn right_shift(self, by: u32) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth; let bitwidth = self.bitwidth;
assert_eq!(by.bitwidth(), UBitwidth::B32); UExpressionInner::RightShift(box self, by).annotate(bitwidth)
UExpressionInner::RightShift(box self, box by).annotate(bitwidth)
} }
} }
@ -170,8 +168,8 @@ pub enum UExpressionInner<'ast, T> {
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), LeftShift(Box<UExpression<'ast, T>>, u32),
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>), RightShift(Box<UExpression<'ast, T>>, u32),
Not(Box<UExpression<'ast, T>>), Not(Box<UExpression<'ast, T>>),
IfElse( IfElse(
Box<BooleanExpression<'ast, T>>, Box<BooleanExpression<'ast, T>>,