Merge pull request #799 from Zokrates/refactor-constant-shift-checks
Fail gracefully on variable shifts
This commit is contained in:
commit
b4f02e7db5
8 changed files with 111 additions and 100 deletions
2
zokrates_cli/examples/compile_errors/variable_shift.zok
Normal file
2
zokrates_cli/examples/compile_errors/variable_shift.zok
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
def main(u32 a, u32 b) -> u32:
|
||||||
|
return a >> b
|
|
@ -1263,18 +1263,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
box FlatExpression::Sub(box new_left, box new_right),
|
box FlatExpression::Sub(box new_left, box new_right),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
UExpressionInner::LeftShift(box e, box by) => {
|
UExpressionInner::LeftShift(box e, by) => {
|
||||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
|
||||||
|
|
||||||
let by = match by.into_inner() {
|
|
||||||
UExpressionInner::Value(n) => n,
|
|
||||||
by => unimplemented!(
|
|
||||||
"Variable shifts are unimplemented, found {} << {}",
|
|
||||||
e,
|
|
||||||
by.annotate(UBitwidth::B32)
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let e = self.flatten_uint_expression(statements_flattened, e);
|
let e = self.flatten_uint_expression(statements_flattened, e);
|
||||||
|
|
||||||
let e_bits = e.bits.unwrap();
|
let e_bits = e.bits.unwrap();
|
||||||
|
@ -1292,18 +1281,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
UExpressionInner::RightShift(box e, box by) => {
|
UExpressionInner::RightShift(box e, by) => {
|
||||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
|
||||||
|
|
||||||
let by = match by.into_inner() {
|
|
||||||
UExpressionInner::Value(n) => n,
|
|
||||||
by => unimplemented!(
|
|
||||||
"Variable shifts are unimplemented, found {} >> {}",
|
|
||||||
e,
|
|
||||||
by.annotate(UBitwidth::B32)
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let e = self.flatten_uint_expression(statements_flattened, e);
|
let e = self.flatten_uint_expression(statements_flattened, e);
|
||||||
|
|
||||||
let e_bits = e.bits.unwrap();
|
let e_bits = e.bits.unwrap();
|
||||||
|
|
|
@ -855,15 +855,23 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
typed_absy::UExpressionInner::LeftShift(box e, box by) => {
|
typed_absy::UExpressionInner::LeftShift(box e, box by) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
let by = f.fold_uint_expression(by);
|
|
||||||
|
|
||||||
zir::UExpressionInner::LeftShift(box e, box by)
|
let by = match by.as_inner() {
|
||||||
|
typed_absy::UExpressionInner::Value(by) => by,
|
||||||
|
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
zir::UExpressionInner::LeftShift(box e, *by as u32)
|
||||||
}
|
}
|
||||||
typed_absy::UExpressionInner::RightShift(box e, box by) => {
|
typed_absy::UExpressionInner::RightShift(box e, box by) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
let by = f.fold_uint_expression(by);
|
|
||||||
|
|
||||||
zir::UExpressionInner::RightShift(box e, box by)
|
let by = match by.as_inner() {
|
||||||
|
typed_absy::UExpressionInner::Value(by) => by,
|
||||||
|
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
zir::UExpressionInner::RightShift(box e, *by as u32)
|
||||||
}
|
}
|
||||||
typed_absy::UExpressionInner::Not(box e) => {
|
typed_absy::UExpressionInner::Not(box e) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
|
|
|
@ -10,6 +10,7 @@ mod flatten_complex_types;
|
||||||
mod propagation;
|
mod propagation;
|
||||||
mod redefinition;
|
mod redefinition;
|
||||||
mod reducer;
|
mod reducer;
|
||||||
|
mod shift_checker;
|
||||||
mod uint_optimizer;
|
mod uint_optimizer;
|
||||||
mod unconstrained_vars;
|
mod unconstrained_vars;
|
||||||
mod variable_read_remover;
|
mod variable_read_remover;
|
||||||
|
@ -20,6 +21,7 @@ use self::flatten_complex_types::Flattener;
|
||||||
use self::propagation::Propagator;
|
use self::propagation::Propagator;
|
||||||
use self::redefinition::RedefinitionOptimizer;
|
use self::redefinition::RedefinitionOptimizer;
|
||||||
use self::reducer::reduce_program;
|
use self::reducer::reduce_program;
|
||||||
|
use self::shift_checker::ShiftChecker;
|
||||||
use self::uint_optimizer::UintOptimizer;
|
use self::uint_optimizer::UintOptimizer;
|
||||||
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
use self::unconstrained_vars::UnconstrainedVariableDetector;
|
||||||
use self::variable_read_remover::VariableReadRemover;
|
use self::variable_read_remover::VariableReadRemover;
|
||||||
|
@ -85,6 +87,8 @@ impl<'ast, T: Field> TypedProgram<'ast, T> {
|
||||||
let r = VariableReadRemover::apply(r);
|
let r = VariableReadRemover::apply(r);
|
||||||
// check array accesses are in bounds
|
// check array accesses are in bounds
|
||||||
let r = BoundsChecker::check(r).map_err(Error::from)?;
|
let r = BoundsChecker::check(r).map_err(Error::from)?;
|
||||||
|
// detect non constant shifts
|
||||||
|
let r = ShiftChecker::check(r).map_err(Error::from)?;
|
||||||
// convert to zir, removing complex types
|
// convert to zir, removing complex types
|
||||||
let zir = Flattener::flatten(r);
|
let zir = Flattener::flatten(r);
|
||||||
// optimize uint expressions
|
// optimize uint expressions
|
||||||
|
|
55
zokrates_core/src/static_analysis/shift_checker.rs
Normal file
55
zokrates_core/src/static_analysis/shift_checker.rs
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
use crate::typed_absy::TypedProgram;
|
||||||
|
use crate::typed_absy::{
|
||||||
|
result_folder::fold_uint_expression_inner, result_folder::ResultFolder, UBitwidth,
|
||||||
|
UExpressionInner,
|
||||||
|
};
|
||||||
|
use zokrates_field::Field;
|
||||||
|
pub struct ShiftChecker;
|
||||||
|
|
||||||
|
impl ShiftChecker {
|
||||||
|
pub fn check<T: Field>(p: TypedProgram<T>) -> Result<TypedProgram<T>, Error> {
|
||||||
|
ShiftChecker.fold_program(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Error = String;
|
||||||
|
|
||||||
|
impl<'ast, T: Field> ResultFolder<'ast, T> for ShiftChecker {
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn fold_uint_expression_inner(
|
||||||
|
&mut self,
|
||||||
|
bitwidth: UBitwidth,
|
||||||
|
e: UExpressionInner<'ast, T>,
|
||||||
|
) -> Result<UExpressionInner<'ast, T>, Error> {
|
||||||
|
match e {
|
||||||
|
UExpressionInner::LeftShift(box e, box by) => {
|
||||||
|
let e = self.fold_uint_expression(e)?;
|
||||||
|
let by = self.fold_uint_expression(by)?;
|
||||||
|
|
||||||
|
match by.as_inner() {
|
||||||
|
UExpressionInner::Value(_) => Ok(UExpressionInner::LeftShift(box e, box by)),
|
||||||
|
by => Err(format!(
|
||||||
|
"Cannot shift by a variable value, found `{} << {}`",
|
||||||
|
e,
|
||||||
|
by.clone().annotate(UBitwidth::B32)
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
UExpressionInner::RightShift(box e, box by) => {
|
||||||
|
let e = self.fold_uint_expression(e)?;
|
||||||
|
let by = self.fold_uint_expression(by)?;
|
||||||
|
|
||||||
|
match by.as_inner() {
|
||||||
|
UExpressionInner::Value(_) => Ok(UExpressionInner::RightShift(box e, box by)),
|
||||||
|
by => Err(format!(
|
||||||
|
"Cannot shift by a variable value, found `{} >> {}`",
|
||||||
|
e,
|
||||||
|
by.clone().annotate(UBitwidth::B32)
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e => fold_uint_expression_inner(self, bitwidth, e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -314,19 +314,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||||
.annotate(range)
|
.annotate(range)
|
||||||
.with_max(range_max)
|
.with_max(range_max)
|
||||||
}
|
}
|
||||||
LeftShift(box e, box by) => {
|
LeftShift(box e, by) => {
|
||||||
// reduce both terms
|
// reduce both terms
|
||||||
let e = self.fold_uint_expression(e);
|
let e = self.fold_uint_expression(e);
|
||||||
let by = self.fold_uint_expression(by);
|
|
||||||
|
|
||||||
let by_max: u128 = by
|
|
||||||
.metadata
|
|
||||||
.clone()
|
|
||||||
.unwrap()
|
|
||||||
.max
|
|
||||||
.to_dec_string()
|
|
||||||
.parse()
|
|
||||||
.unwrap();
|
|
||||||
let e_max: u128 = e
|
let e_max: u128 = e
|
||||||
.metadata
|
.metadata
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -336,20 +327,13 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let max = T::from((e_max << by_max) & (2_u128.pow(range as u32) - 1));
|
let max = T::from((e_max << by) & (2_u128.pow(range as u32) - 1));
|
||||||
|
|
||||||
UExpression::left_shift(force_reduce(e), force_reduce(by)).with_max(max)
|
UExpression::left_shift(force_reduce(e), by).with_max(max)
|
||||||
}
|
}
|
||||||
RightShift(box e, box by) => {
|
RightShift(box e, by) => {
|
||||||
// reduce both terms
|
// reduce both terms
|
||||||
let e = self.fold_uint_expression(e);
|
let e = self.fold_uint_expression(e);
|
||||||
let by = self.fold_uint_expression(by);
|
|
||||||
|
|
||||||
// if we don't know the amount by which we shift, the most conservative case (which leads to the biggest value) is 0
|
|
||||||
let by_u = match by.as_inner() {
|
|
||||||
UExpressionInner::Value(by) => *by,
|
|
||||||
_ => 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
let e_max: u128 = e
|
let e_max: u128 = e
|
||||||
.metadata
|
.metadata
|
||||||
|
@ -360,11 +344,11 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by_u;
|
let max = (e_max & (2_u128.pow(range as u32) - 1)) >> by;
|
||||||
|
|
||||||
let max = T::from(max);
|
let max = T::from(max);
|
||||||
|
|
||||||
UExpression::right_shift(force_reduce(e), force_reduce(by)).with_max(max)
|
UExpression::right_shift(force_reduce(e), by).with_max(max)
|
||||||
}
|
}
|
||||||
IfElse(box condition, box consequence, box alternative) => {
|
IfElse(box condition, box consequence, box alternative) => {
|
||||||
let condition = self.fold_boolean_expression(condition);
|
let condition = self.fold_boolean_expression(condition);
|
||||||
|
@ -694,30 +678,14 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn right_shift() {
|
fn right_shift() {
|
||||||
// left argument in range, we reduce (no effect) and the max is the original max, as we could be shifting by 0
|
|
||||||
uint_test!(0xff_u32, true, 2, true, right_shift, 0xff_u32);
|
|
||||||
uint_test!(2, true, 2, true, right_shift, 2_u32);
|
|
||||||
|
|
||||||
// left argument out of range, we reduce and the max is the type max, shifted
|
|
||||||
uint_test!(
|
|
||||||
0xffffffffffff_u128,
|
|
||||||
true,
|
|
||||||
2,
|
|
||||||
true,
|
|
||||||
right_shift,
|
|
||||||
0xffffffff_u32
|
|
||||||
);
|
|
||||||
|
|
||||||
fn right_shift_test(e_max: u128, by: u32, output_max: u32) {
|
fn right_shift_test(e_max: u128, by: u32, output_max: u32) {
|
||||||
let left = e_with_max(e_max);
|
let left = e_with_max(e_max);
|
||||||
|
|
||||||
let right = UExpressionInner::Value(by as u128)
|
let right = by;
|
||||||
.annotate(crate::zir::types::UBitwidth::B32)
|
|
||||||
.with_max(by);
|
|
||||||
|
|
||||||
let left_expected = force_reduce(left.clone());
|
let left_expected = force_reduce(left.clone());
|
||||||
|
|
||||||
let right_expected = force_reduce(right.clone());
|
let right_expected = right;
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
UintOptimizer::new()
|
UintOptimizer::new()
|
||||||
|
@ -733,25 +701,25 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn left_shift() {
|
fn left_shift() {
|
||||||
uint_test!(0xff_u32, true, 2, true, left_shift, 0xff_u32 << 2);
|
fn left_shift_test(e_max: u128, by: u32, output_max: u32) {
|
||||||
uint_test!(
|
let left = e_with_max(e_max);
|
||||||
0xffffffff_u32,
|
|
||||||
true,
|
|
||||||
2,
|
|
||||||
true,
|
|
||||||
left_shift,
|
|
||||||
0xffffffff_u32 << 2
|
|
||||||
);
|
|
||||||
|
|
||||||
// left argument out of range, we reduce and the max is the type max, shifted
|
let right = by;
|
||||||
uint_test!(
|
|
||||||
0xffffffffffff_u128,
|
let left_expected = force_reduce(left.clone());
|
||||||
true,
|
|
||||||
2,
|
let right_expected = right;
|
||||||
true,
|
|
||||||
left_shift,
|
assert_eq!(
|
||||||
0xffffffff_u32 << 2
|
UintOptimizer::new()
|
||||||
)
|
.fold_uint_expression(UExpression::left_shift(left.clone(), right.clone())),
|
||||||
|
UExpression::left_shift(left_expected, right_expected).with_max(output_max)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
left_shift_test(0xff_u128, 2, 0xff << 2);
|
||||||
|
left_shift_test(2, 2, 2 << 2);
|
||||||
|
left_shift_test(0xffffffffffff_u128, 2, 0xffffffff << 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
@ -308,17 +308,15 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
|
|
||||||
UExpressionInner::Or(box left, box right)
|
UExpressionInner::Or(box left, box right)
|
||||||
}
|
}
|
||||||
UExpressionInner::LeftShift(box e, box by) => {
|
UExpressionInner::LeftShift(box e, by) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
let by = f.fold_uint_expression(by);
|
|
||||||
|
|
||||||
UExpressionInner::LeftShift(box e, box by)
|
UExpressionInner::LeftShift(box e, by)
|
||||||
}
|
}
|
||||||
UExpressionInner::RightShift(box e, box by) => {
|
UExpressionInner::RightShift(box e, by) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
let by = f.fold_uint_expression(by);
|
|
||||||
|
|
||||||
UExpressionInner::RightShift(box e, box by)
|
UExpressionInner::RightShift(box e, by)
|
||||||
}
|
}
|
||||||
UExpressionInner::Not(box e) => {
|
UExpressionInner::Not(box e) => {
|
||||||
let e = f.fold_uint_expression(e);
|
let e = f.fold_uint_expression(e);
|
||||||
|
|
|
@ -57,16 +57,14 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
||||||
UExpressionInner::And(box self, box other).annotate(bitwidth)
|
UExpressionInner::And(box self, box other).annotate(bitwidth)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn left_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
pub fn left_shift(self, by: u32) -> UExpression<'ast, T> {
|
||||||
let bitwidth = self.bitwidth;
|
let bitwidth = self.bitwidth;
|
||||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
UExpressionInner::LeftShift(box self, by).annotate(bitwidth)
|
||||||
UExpressionInner::LeftShift(box self, box by).annotate(bitwidth)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn right_shift(self, by: UExpression<'ast, T>) -> UExpression<'ast, T> {
|
pub fn right_shift(self, by: u32) -> UExpression<'ast, T> {
|
||||||
let bitwidth = self.bitwidth;
|
let bitwidth = self.bitwidth;
|
||||||
assert_eq!(by.bitwidth(), UBitwidth::B32);
|
UExpressionInner::RightShift(box self, by).annotate(bitwidth)
|
||||||
UExpressionInner::RightShift(box self, box by).annotate(bitwidth)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,8 +168,8 @@ pub enum UExpressionInner<'ast, T> {
|
||||||
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
Xor(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||||
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
And(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||||
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
Or(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||||
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
LeftShift(Box<UExpression<'ast, T>>, u32),
|
||||||
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
RightShift(Box<UExpression<'ast, T>>, u32),
|
||||||
Not(Box<UExpression<'ast, T>>),
|
Not(Box<UExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(
|
||||||
Box<BooleanExpression<'ast, T>>,
|
Box<BooleanExpression<'ast, T>>,
|
||||||
|
|
Loading…
Reference in a new issue