1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #799 from Zokrates/refactor-constant-shift-checks

Fail gracefully on variable shifts
This commit is contained in:
Thibaut Schaeffer 2021-04-19 20:51:44 +02:00 committed by GitHub
commit b4f02e7db5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 111 additions and 100 deletions

View file

@ -0,0 +1,2 @@
def main(u32 a, u32 b) -> u32:
return a >> b

View file

@ -1263,18 +1263,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box FlatExpression::Sub(box new_left, box new_right),
))
}
UExpressionInner::LeftShift(box e, box by) => {
assert_eq!(by.bitwidth(), UBitwidth::B32);
let by = match by.into_inner() {
UExpressionInner::Value(n) => n,
by => unimplemented!(
"Variable shifts are unimplemented, found {} << {}",
e,
by.annotate(UBitwidth::B32)
),
};
UExpressionInner::LeftShift(box e, by) => {
let e = self.flatten_uint_expression(statements_flattened, e);
let e_bits = e.bits.unwrap();
@ -1292,18 +1281,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
.collect::<Vec<_>>(),
)
}
UExpressionInner::RightShift(box e, box by) => {
assert_eq!(by.bitwidth(), UBitwidth::B32);
let by = match by.into_inner() {
UExpressionInner::Value(n) => n,
by => unimplemented!(
"Variable shifts are unimplemented, found {} >> {}",
e,
by.annotate(UBitwidth::B32)
),
};
UExpressionInner::RightShift(box e, by) => {
let e = self.flatten_uint_expression(statements_flattened, e);
let e_bits = e.bits.unwrap();

View file

@ -855,15 +855,23 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
}
typed_absy::UExpressionInner::LeftShift(box e, box by) => {
let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
zir::UExpressionInner::LeftShift(box e, box by)
let by = match by.as_inner() {
typed_absy::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::UExpressionInner::LeftShift(box e, *by as u32)
}
typed_absy::UExpressionInner::RightShift(box e, box by) => {
let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
zir::UExpressionInner::RightShift(box e, box by)
let by = match by.as_inner() {
typed_absy::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::UExpressionInner::RightShift(box e, *by as u32)
}
typed_absy::UExpressionInner::Not(box e) => {
let e = f.fold_uint_expression(e);

View file

@ -10,6 +10,7 @@ mod flatten_complex_types;
mod propagation;
mod redefinition;
mod reducer;
mod shift_checker;
mod uint_optimizer;
mod unconstrained_vars;
mod variable_read_remover;
@ -20,6 +21,7 @@ use self::flatten_complex_types::Flattener;
use self::propagation::Propagator;
use self::redefinition::RedefinitionOptimizer;
use self::reducer::reduce_program;
use self::shift_checker::ShiftChecker;
use self::uint_optimizer::UintOptimizer;
use self::unconstrained_vars::UnconstrainedVariableDetector;
use self::variable_read_remover::VariableReadRemover;
@ -85,6 +87,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
let r = VariableReadRemover::apply(r);
// check array accesses are in bounds
let r = BoundsChecker::check(r).map_err(Error::from)?;
// detect non constant shifts
let r = ShiftChecker::check(r).map_err(Error::from)?;
// convert to zir, removing complex types
let zir = Flattener::flatten(r);
// optimize uint expressions

View file

@ -0,0 +1,55 @@
use crate::typed_absy::TypedProgram;
use crate::typed_absy::{
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
UExpressionInner,
};
use zokrates_field::Field;
pub struct ShiftChecker;
impl ShiftChecker {
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
ShiftChecker.fold_program(p)
}
}
pub type Error = String;
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
type Error = Error;
fn fold_uint_expression_inner(
&mut self,
bitwidth: UBitwidth,
e: UExpressionInner<'ast, T>,
) -> Result<UExpressionInner<'ast, T>, Error> {
match e {
UExpressionInner::LeftShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
UExpressionInner::RightShift(box e, box by) => {
let e = self.fold_uint_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
by => Err(format!(
"Cannot shift by a variable value, found `{} >> {}`",
e,
by.clone().annotate(UBitwidth::B32)
)),
}
}
e => fold_uint_expression_inner(self, bitwidth, e),
}
}
}

View file

@ -314,19 +314,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.annotate(range)
.with_max(range_max)
}
LeftShift(box e, box by) => {
LeftShift(box e, by) => {
// reduce both terms
let e = self.fold_uint_expression(e);
let by = self.fold_uint_expression(by);
let by_max: u128 = by
.metadata
.clone()
.unwrap()
.max
.to_dec_string()
.parse()
.unwrap();
let e_max: u128 = e
.metadata
.clone()
@ -336,20 +327,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.parse()
.unwrap();
let max = T::from((e_max << by_max) & (2_u128.pow(range as u32) - 1));
let max = T::from((e_max << by) & (2_u128.pow(range as u32) - 1));
UExpression::left_shift(force_reduce(e), force_reduce(by)).with_max(max)
UExpression::left_shift(force_reduce(e), by).with_max(max)
}
RightShift(box e, box by) => {
RightShift(box e, by) => {
// reduce both terms
let e = self.fold_uint_expression(e);
let by = self.fold_uint_expression(by);
// if we don't know the amount by which we shift, the most conservative case (which leads to the biggest value) is 0
let by_u = match by.as_inner() {
UExpressionInner::Value(by) => *by,
_ => 0,
};
let e_max: u128 = e
.metadata
@ -360,11 +344,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
.parse()
.unwrap();
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by_u;
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by;
let max = T::from(max);
UExpression::right_shift(force_reduce(e), force_reduce(by)).with_max(max)
UExpression::right_shift(force_reduce(e), by).with_max(max)
}
IfElse(box condition, box consequence, box alternative) => {
let condition = self.fold_boolean_expression(condition);
@ -694,30 +678,14 @@ mod tests {
#[test]
fn right_shift() {
// left argument in range, we reduce (no effect) and the max is the original max, as we could be shifting by 0
uint_test!(0xff_u32, true, 2, true, right_shift, 0xff_u32);
uint_test!(2, true, 2, true, right_shift, 2_u32);
// left argument out of range, we reduce and the max is the type max, shifted
uint_test!(
0xffffffffffff_u128,
true,
2,
true,
right_shift,
0xffffffff_u32
);
fn right_shift_test(e_max: u128, by: u32, output_max: u32) {
let left = e_with_max(e_max);
let right = UExpressionInner::Value(by as u128)
.annotate(crate::zir::types::UBitwidth::B32)
.with_max(by);
let right = by;
let left_expected = force_reduce(left.clone());
let right_expected = force_reduce(right.clone());
let right_expected = right;
assert_eq!(
UintOptimizer::new()
@ -733,25 +701,25 @@ mod tests {
#[test]
fn left_shift() {
uint_test!(0xff_u32, true, 2, true, left_shift, 0xff_u32 << 2);
uint_test!(
0xffffffff_u32,
true,
2,
true,
left_shift,
0xffffffff_u32 << 2
);
fn left_shift_test(e_max: u128, by: u32, output_max: u32) {
let left = e_with_max(e_max);
// left argument out of range, we reduce and the max is the type max, shifted
uint_test!(
0xffffffffffff_u128,
true,
2,
true,
left_shift,
0xffffffff_u32 << 2
)
let right = by;
let left_expected = force_reduce(left.clone());
let right_expected = right;
assert_eq!(
UintOptimizer::new()
.fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
);
}
left_shift_test(0xff_u128, 2, 0xff << 2);
left_shift_test(2, 2, 2 << 2);
left_shift_test(0xffffffffffff_u128, 2, 0xffffffff << 2);
}
#[test]

View file

@ -308,17 +308,15 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
UExpressionInner::Or(box left, box right)
}
UExpressionInner::LeftShift(box e, box by) => {
UExpressionInner::LeftShift(box e, by) => {
let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
UExpressionInner::LeftShift(box e, box by)
UExpressionInner::LeftShift(box e, by)
}
UExpressionInner::RightShift(box e, box by) => {
UExpressionInner::RightShift(box e, by) => {
let e = f.fold_uint_expression(e);
let by = f.fold_uint_expression(by);
UExpressionInner::RightShift(box e, box by)
UExpressionInner::RightShift(box e, by)
}
UExpressionInner::Not(box e) => {
let e = f.fold_uint_expression(e);

View file

@ -57,16 +57,14 @@ impl<'ast, T: Field> UExpression<'ast, T> {
UExpressionInner::And(box self, box other).annotate(bitwidth)
}
pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
pub fn left_shift(self, by: u32) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(by.bitwidth(), UBitwidth::B32);
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
UExpressionInner::LeftShift(box self, by).annotate(bitwidth)
}
pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
pub fn right_shift(self, by: u32) -> UExpression<'ast, T> {
let bitwidth = self.bitwidth;
assert_eq!(by.bitwidth(), UBitwidth::B32);
UExpressionInner::RightShift(box self, box by).annotate(bitwidth)
UExpressionInner::RightShift(box self, by).annotate(bitwidth)
}
}
@ -170,8 +168,8 @@ pub enum UExpressionInner<'ast, T> {
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
LeftShift(Box<UExpression<'ast, T>>, u32),
RightShift(Box<UExpression<'ast, T>>, u32),
Not(Box<UExpression<'ast, T>>),
IfElse(
Box<BooleanExpression<'ast, T>>,