1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

extract panics just before flattening, simplify zir, remove redundant checks in code generation

This commit is contained in:
schaeff 2022-08-16 18:53:21 +02:00
parent 8caa6b4720
commit 3fbf63d335
18 changed files with 393 additions and 541 deletions

View file

@ -38,6 +38,7 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s),
crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck,
crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse,
crate::zir::RuntimeError::IncompleteDynamicRange => RuntimeError::LtFinalSum,
}
}
}

View file

@ -188,15 +188,6 @@ impl<T: Field> fmt::Display for Prog<T> {
for s in &self.statements {
writeln!(f, "\t{}", s)?;
}
writeln!(
f,
"\treturn {}",
(0..self.return_count)
.map(Variable::public)
.map(|e| format!("{}", e))
.collect::<Vec<_>>()
.join(", ")
)?;
writeln!(f, "\treturn {}", returns)?;
writeln!(f, "}}")

View file

@ -212,41 +212,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintGe(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
let e2 = f.fold_boolean_expression(e2);

View file

@ -8,7 +8,7 @@ mod uint;
mod variable;
pub use self::parameter::Parameter;
pub use self::types::Type;
pub use self::types::{Type, UBitwidth};
pub use self::variable::Variable;
use crate::common::{FlatEmbed, FormatString};
use crate::typed::ConcreteType;
@ -23,7 +23,7 @@ pub use self::folder::Folder;
pub use self::identifier::{Identifier, SourceIdentifier};
/// A typed program as a collection of modules, one of them being the main
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct ZirProgram<'ast, T> {
pub main: ZirFunction<'ast, T>,
}
@ -93,6 +93,7 @@ pub enum RuntimeError {
SourceAssertion(String),
SelectRangeCheck,
DivisionByZero,
IncompleteDynamicRange,
}
impl fmt::Display for RuntimeError {
@ -101,6 +102,7 @@ impl fmt::Display for RuntimeError {
RuntimeError::SourceAssertion(message) => write!(f, "{}", message),
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
RuntimeError::IncompleteDynamicRange => write!(f, "Dynamic comparison is incomplete"),
}
}
}
@ -140,14 +142,15 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
write!(f, "{}", "\t".repeat(depth))?;
match self {
ZirStatement::Return(ref exprs) => {
write!(f, "return ")?;
for (i, expr) in exprs.iter().enumerate() {
write!(f, "{}", expr)?;
if i < exprs.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, ";")
write!(
f,
"return {};",
exprs
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ")
)
}
ZirStatement::Definition(ref lhs, ref rhs) => {
write!(f, "{} = {};", lhs, rhs)
@ -183,7 +186,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
}
ZirStatement::Log(ref l, ref expressions) => write!(
f,
"log(\"{}\"), {})",
"log(\"{}\"), {});",
l,
expressions
.iter()
@ -335,22 +338,12 @@ pub enum BooleanExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGe(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldGt(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
FieldEq(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
BoolEq(
Box<BooleanExpression<'ast, T>>,
@ -489,7 +482,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs),
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
UExpressionInner::Not(ref e) => write!(f, "!({})", e),
UExpressionInner::Conditional(ref condition, ref consequent, ref alternative) => {
write!(
f,
@ -515,20 +508,16 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
.join(", "),
i
),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!({})", exp),
BooleanExpression::Conditional(ref condition, ref consequent, ref alternative) => {
write!(
f,

View file

@ -246,41 +246,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGt(box e1, box e2)
}
BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldGe(box e1, box e2)
}
BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLt(box e1, box e2)
}
BooleanExpression::FieldLe(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldLe(box e1, box e2)
}
BooleanExpression::UintLe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintLe(box e1, box e2)
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGt(box e1, box e2)
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintGe(box e1, box e2)
}
BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;

View file

@ -9,5 +9,5 @@ def main(field a) {
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1);
bool c = a < maxvalue + 1;
return
return;
}

View file

@ -9,5 +9,5 @@ def main(field a) {
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1);
bool c = maxvalue + 1 < a;
return
return;
}

View file

@ -774,66 +774,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let sub_width = bit_width + 1;
// define variables for the bits
let shifted_sub_bits_be: Vec<Variable> =
(0..sub_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
shifted_sub_bits_be.clone(),
Solver::bits(sub_width),
vec![shifted_sub.clone()],
)));
// bitness checks
for bit in shifted_sub_bits_be.iter() {
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Identifier(*bit),
FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Identifier(*bit),
),
RuntimeError::LtFinalBitness,
));
}
// sum(sym_b{i} * 2**i)
let mut expr = FlatExpression::Number(T::from(0));
for (i, bit) in shifted_sub_bits_be.iter().take(sub_width).enumerate() {
expr = FlatExpression::Add(
box expr,
box FlatExpression::Mult(
box FlatExpression::Identifier(*bit),
box FlatExpression::Number(T::from(2).pow(sub_width - i - 1)),
),
);
}
statements_flattened.push_back(FlatStatement::Condition(
shifted_sub,
expr,
RuntimeError::LtFinalSum,
));
// to make this check symetric, we ban the value `a - b == -2**N`, as the value `a - b == 2**N` is already banned
let fail = self.eq_check(
let shifted_sub_bits_be = self.get_bits_unchecked(
&FlatUExpression::with_field(shifted_sub),
sub_width,
sub_width,
statements_flattened,
FlatExpression::Sub(
box FlatExpression::Identifier(rhs_id),
box FlatExpression::Identifier(lhs_id),
),
FlatExpression::Number(T::from(2).pow(bit_width)),
RuntimeError::LtFinalSum,
);
statements_flattened.push_back(FlatStatement::Condition(
fail,
FlatExpression::Number(T::from(0)),
RuntimeError::LtSymetric,
));
FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(shifted_sub_bits_be[0]),
box shifted_sub_bits_be[0].clone(),
)
}
}
@ -949,14 +900,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::FieldLt(rhs, lhs),
),
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::FieldLe(rhs, lhs),
),
BooleanExpression::UintLt(box lhs, box rhs) => {
let bit_width = lhs.bitwidth.to_usize();
assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
@ -987,14 +930,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
FlatExpression::Add(box eq, box lt)
}
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::UintLt(rhs, lhs),
),
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
statements_flattened,
BooleanExpression::UintLe(rhs, lhs),
),
BooleanExpression::Or(box lhs, box rhs) => {
let x = self.flatten_boolean_expression(statements_flattened, lhs);
let y = self.flatten_boolean_expression(statements_flattened, rhs);
@ -1368,23 +1303,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Identifier(id)
};
// first check that the d is not 0 by giving its inverse
let invd = self.use_sym();
// # invd = 1/d
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![invd],
Solver::Div,
vec![FlatExpression::Number(T::one()), d.clone()],
)));
// assert(invd * d == 1)
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()),
FlatExpression::Mult(box invd.into(), box d.clone()),
RuntimeError::Inverse,
));
// now introduce the quotient and remainder
let q = self.use_sym();
let r = self.use_sym();
@ -2184,23 +2102,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
id.into()
};
let invb = self.use_sym();
// `right` is assumed to already be non-zero so this is an unchecked division
// TODO: we could save one constraint here by reusing the inverse of `right` computed earlier
let inverse = self.use_sym();
// # invb = 1/b
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![invb],
Solver::Div,
vec![FlatExpression::Number(T::one()), new_right.clone()],
)));
// assert(invb * b == 1)
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()),
FlatExpression::Mult(box invb.into(), box new_right.clone()),
RuntimeError::Inverse,
));
// # c = a/b
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
vec![inverse],
@ -2428,8 +2334,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
error.into(),
)
}
BooleanExpression::FieldLt(box lhs, box rhs)
| BooleanExpression::FieldGt(box rhs, box lhs) => {
BooleanExpression::FieldLt(box lhs, box rhs) => {
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let rhs = self.flatten_field_expression(statements_flattened, rhs);
@ -2459,8 +2364,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
}
BooleanExpression::FieldLe(box lhs, box rhs)
| BooleanExpression::FieldGe(box rhs, box lhs) => {
BooleanExpression::FieldLe(box lhs, box rhs) => {
let lhs = self.flatten_field_expression(statements_flattened, lhs);
let rhs = self.flatten_field_expression(statements_flattened, rhs);
@ -2490,6 +2394,40 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
}
BooleanExpression::UintLe(box lhs, box rhs) => {
let lhs = self
.flatten_uint_expression(statements_flattened, lhs)
.get_field_unchecked();
let rhs = self
.flatten_uint_expression(statements_flattened, rhs)
.get_field_unchecked();
match (lhs, rhs) {
(e, FlatExpression::Number(c)) => self.enforce_constant_le_check(
statements_flattened,
e,
c,
error.into(),
),
// c <= e <=> p - 1 - e <= p - 1 - c
(FlatExpression::Number(c), e) => self.enforce_constant_le_check(
statements_flattened,
FlatExpression::Sub(box T::max_value().into(), box e),
T::max_value() - c,
error.into(),
),
(lhs, rhs) => {
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // dynamic comparison is not complete
let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
statements_flattened.push_back(FlatStatement::Condition(
e,
FlatExpression::Number(T::one()),
error.into(),
));
}
}
}
BooleanExpression::UintEq(box lhs, box rhs) => {
let lhs = self
.flatten_uint_expression(statements_flattened, lhs)
@ -2516,6 +2454,74 @@ impl<'ast, T: Field> Flattener<'ast, T> {
error.into(),
)
}
// `!(x == 0)` can be asserted by giving the inverse of `x`
BooleanExpression::Not(box BooleanExpression::UintEq(
box UExpression {
inner: UExpressionInner::Value(0),
..
},
box x,
))
| BooleanExpression::Not(box BooleanExpression::UintEq(
box x,
box UExpression {
inner: UExpressionInner::Value(0),
..
},
)) => {
let x = self
.flatten_uint_expression(statements_flattened, x)
.get_field_unchecked();
// check that `x` is not 0 by giving its inverse
let invx = self.use_sym();
// # invx = 1/x
statements_flattened.push_back(FlatStatement::Directive(
FlatDirective::new(
vec![invx],
Solver::Div,
vec![FlatExpression::Number(T::one()), x.clone()],
),
));
// assert(invx * x == 1)
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()),
FlatExpression::Mult(box invx.into(), box x.clone()),
RuntimeError::Inverse,
));
}
// `!(x == 0)` can be asserted by giving the inverse of `x`
BooleanExpression::Not(box BooleanExpression::FieldEq(
box FieldElementExpression::Number(zero),
box x,
))
| BooleanExpression::Not(box BooleanExpression::FieldEq(
box x,
box FieldElementExpression::Number(zero),
)) if zero == T::from(0) => {
let x = self.flatten_field_expression(statements_flattened, x);
// check that `x` is not 0 by giving its inverse
let invx = self.use_sym();
// # invx = 1/x
statements_flattened.push_back(FlatStatement::Directive(
FlatDirective::new(
vec![invx],
Solver::Div,
vec![FlatExpression::Number(T::one()), x.clone()],
),
));
// assert(invx * x == 1)
statements_flattened.push_back(FlatStatement::Condition(
FlatExpression::Number(T::one()),
FlatExpression::Mult(box invx.into(), box x.clone()),
RuntimeError::Inverse,
));
}
e => {
// naive approach: flatten the boolean to a single field element and constrain it to 1
let e = self.flatten_boolean_expression(statements_flattened, e);
@ -3545,24 +3551,6 @@ mod tests {
flattener.flatten_field_expression(&mut FlatStatements::new(), expression);
}
#[test]
fn geq_leq() {
let config = CompileConfig::default();
let mut flattener = Flattener::new(config);
let expression_le = BooleanExpression::FieldLe(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le);
let mut flattener = Flattener::new(config);
let expression_ge = BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(32)),
box FieldElementExpression::Number(Bn128Field::from(4)),
);
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_ge);
}
#[test]
fn bool_and() {
let config = CompileConfig::default();
@ -3619,18 +3607,14 @@ mod tests {
// define new wires for members of Div
let five = Variable::new(1);
let b0 = Variable::new(2);
// Define inverse of denominator to prevent div by 0
let invb0 = Variable::new(3);
// Define inverse
let sym_0 = Variable::new(4);
let sym_0 = Variable::new(3);
// Define result, which is first member to next Div
let sym_1 = Variable::new(5);
let sym_1 = Variable::new(4);
// Define second member
let b1 = Variable::new(6);
// Define inverse of denominator to prevent div by 0
let invb1 = Variable::new(7);
let b1 = Variable::new(5);
// Define inverse
let sym_2 = Variable::new(8);
let sym_2 = Variable::new(6);
assert_eq!(
statements_flattened,
@ -3639,17 +3623,6 @@ mod tests {
// inputs to first div (5/b)
FlatStatement::Definition(five, FlatExpression::Number(Bn128Field::from(5))),
FlatStatement::Definition(b0, b.into()),
// check div by 0
FlatStatement::Directive(FlatDirective::new(
vec![invb0],
Solver::Div,
vec![FlatExpression::Number(Bn128Field::from(1)), b0.into()]
)),
FlatStatement::Condition(
FlatExpression::Number(Bn128Field::from(1)),
FlatExpression::Mult(box invb0.into(), box b0.into()),
RuntimeError::Inverse,
),
// execute div
FlatStatement::Directive(FlatDirective::new(
vec![sym_0],
@ -3664,17 +3637,6 @@ mod tests {
// inputs to second div (res/b)
FlatStatement::Definition(sym_1, sym_0.into()),
FlatStatement::Definition(b1, b.into()),
// check div by 0
FlatStatement::Directive(FlatDirective::new(
vec![invb1],
Solver::Div,
vec![FlatExpression::Number(Bn128Field::from(1)), b1.into()]
)),
FlatStatement::Condition(
FlatExpression::Number(Bn128Field::from(1)),
FlatExpression::Mult(box invb1.into(), box b1.into()),
RuntimeError::Inverse
),
// execute div
FlatStatement::Directive(FlatDirective::new(
vec![sym_2],

View file

@ -67,7 +67,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
let condition_id = Identifier::from(CoreIdentifier::Condition(self.index));
self.buffer.push(TypedStatement::definition(
Variable::immutable(condition_id.clone(), Type::Boolean).into(),
TypedExpression::from(condition).into(),
TypedExpression::from(condition),
));
self.index += 1;
BooleanExpression::Identifier(condition_id)

View file

@ -1,24 +1,20 @@
use std::collections::HashSet;
use zokrates_ast::typed::{
folder::*, BlockExpression, Identifier, TypedAssignee, TypedFunction, TypedProgram,
TypedStatement,
};
use zokrates_ast::zir::{folder::*, Identifier, ZirFunction, ZirProgram, ZirStatement};
use zokrates_field::Field;
#[derive(Default)]
pub struct DeadCodeEliminator<'ast> {
used: HashSet<Identifier<'ast>>,
in_block: usize,
}
impl<'ast> DeadCodeEliminator<'ast> {
pub fn eliminate<T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
pub fn eliminate<T: Field>(p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
Self::default().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
fn fold_function(&mut self, f: ZirFunction<'ast, T>) -> ZirFunction<'ast, T> {
// iterate on the statements starting from the end, as we want to see usage before definition
let mut statements: Vec<_> = f
.statements
@ -27,51 +23,44 @@ impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
.flat_map(|s| self.fold_statement(s))
.collect();
statements.reverse();
TypedFunction { statements, ..f }
ZirFunction { statements, ..f }
}
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
match s {
TypedStatement::Definition(a, e) => match a {
TypedAssignee::Identifier(ref id) => {
// if the lhs is used later in the program and we're in a block
if self.used.remove(&id.id) {
// include this statement
fold_statement(self, TypedStatement::Definition(a, e))
} else {
// otherwise remove it
vec![]
}
ZirStatement::Definition(v, e) => {
// if the lhs is used later in the program and we're in a block
if self.used.remove(&v.id) {
// include this statement
fold_statement(self, ZirStatement::Definition(v, e))
} else {
// otherwise remove it
vec![]
}
_ => fold_statement(self, TypedStatement::Definition(a, e)),
},
TypedStatement::For(..) => {
unreachable!("for loops should be removed before dead code elimination is run")
}
ZirStatement::IfElse(condition, consequence, alternative) => {
let condition = self.fold_boolean_expression(condition);
let mut consequence: Vec<_> = consequence
.into_iter()
.rev()
.flat_map(|e| self.fold_statement(e))
.collect();
consequence.reverse();
let mut alternative: Vec<_> = alternative
.into_iter()
.rev()
.flat_map(|e| self.fold_statement(e))
.collect();
alternative.reverse();
vec![ZirStatement::IfElse(condition, consequence, alternative)]
}
s => fold_statement(self, s),
}
}
fn fold_block_expression<E: Fold<'ast, T>>(
&mut self,
block: BlockExpression<'ast, T, E>,
) -> BlockExpression<'ast, T, E> {
self.in_block += 1;
let value = box block.value.fold(self);
let mut statements: Vec<_> = block
.statements
.into_iter()
.rev()
.flat_map(|s| self.fold_statement(s))
.collect();
statements.reverse();
let block = BlockExpression { value, statements };
self.in_block -= 1;
block
}
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
self.used.insert(n.clone());
n

View file

@ -991,12 +991,12 @@ fn fold_boolean_expression<'ast, T: Field>(
typed::BooleanExpression::FieldGt(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
zir::BooleanExpression::FieldGt(box e1, box e2)
zir::BooleanExpression::FieldLt(box e2, box e1)
}
typed::BooleanExpression::FieldGe(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
zir::BooleanExpression::FieldGe(box e1, box e2)
zir::BooleanExpression::FieldLe(box e2, box e1)
}
typed::BooleanExpression::UintLt(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
@ -1011,12 +1011,12 @@ fn fold_boolean_expression<'ast, T: Field>(
typed::BooleanExpression::UintGt(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
let e2 = f.fold_uint_expression(statements_buffer, e2);
zir::BooleanExpression::UintGt(box e1, box e2)
zir::BooleanExpression::UintLt(box e2, box e1)
}
typed::BooleanExpression::UintGe(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
let e2 = f.fold_uint_expression(statements_buffer, e2);
zir::BooleanExpression::UintGe(box e1, box e2)
zir::BooleanExpression::UintLe(box e2, box e1)
}
typed::BooleanExpression::Or(box e1, box e2) => {
let e1 = f.fold_boolean_expression(statements_buffer, e1);

View file

@ -166,14 +166,6 @@ pub fn analyse<'ast, T: Field>(
let r = ConditionRedefiner::redefine(r);
log::trace!("\n{}", r);
log::debug!("Static analyser: Extract panics");
let r = PanicExtractor::extract(r);
log::trace!("\n{}", r);
log::debug!("Static analyser: Remove dead code");
let r = DeadCodeEliminator::eliminate(r);
log::trace!("\n{}", r);
// convert to zir, removing complex types
log::debug!("Static analyser: Convert to zir");
let zir = Flattener::flatten(r);
@ -184,6 +176,14 @@ pub fn analyse<'ast, T: Field>(
let zir = ZirPropagator::propagate(zir).map_err(Error::from)?;
log::trace!("\n{}", zir);
log::debug!("Static analyser: Extract panics");
let zir = PanicExtractor::extract(zir);
log::trace!("\n{}", zir);
log::debug!("Static analyser: Remove dead code");
let zir = DeadCodeEliminator::eliminate(zir);
log::trace!("\n{}", zir);
// optimize uint expressions
log::debug!("Static analyser: Optimize uints");
let zir = UintOptimizer::optimize(zir);

View file

@ -1,6 +1,6 @@
use zokrates_ast::typed::{
folder::*, BooleanExpression, EqExpression, FieldElementExpression, RuntimeError, TypedProgram,
TypedStatement, UBitwidth, UExpressionInner,
use zokrates_ast::zir::{
folder::*, BooleanExpression, FieldElementExpression, RuntimeError, UBitwidth, UExpression,
UExpressionInner, ZirProgram, ZirStatement,
};
use zokrates_field::Field;
@ -8,23 +8,47 @@ use zokrates_field::Field;
#[derive(Default)]
pub struct PanicExtractor<'ast, T> {
panic_buffer: Vec<(BooleanExpression<'ast, T>, RuntimeError)>,
panic_buffer: Vec<ZirStatement<'ast, T>>,
}
impl<'ast, T: Field> PanicExtractor<'ast, T> {
pub fn extract(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
pub fn extract(p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
Self::default().fold_program(p)
}
}
impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
let s = fold_statement(self, s);
self.panic_buffer
.drain(..)
.map(|(b, e)| TypedStatement::Assertion(b, e))
.chain(s)
.collect()
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
match s {
ZirStatement::IfElse(condition, consequence, alternative) => {
let condition = self.fold_boolean_expression(condition);
let mut consequence_extractor = Self::default();
let consequence = consequence
.into_iter()
.flat_map(|s| consequence_extractor.fold_statement(s))
.collect();
assert!(consequence_extractor.panic_buffer.is_empty());
let mut alternative_extractor = Self::default();
let alternative = alternative
.into_iter()
.flat_map(|s| alternative_extractor.fold_statement(s))
.collect();
assert!(alternative_extractor.panic_buffer.is_empty());
self.panic_buffer
.drain(..)
.chain(std::iter::once(ZirStatement::IfElse(
condition,
consequence,
alternative,
)))
.collect()
}
s => {
let s = fold_statement(self, s);
self.panic_buffer.drain(..).chain(s).collect()
}
}
}
fn fold_field_expression(
@ -35,15 +59,34 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
FieldElementExpression::Div(box n, box d) => {
let n = self.fold_field_expression(n);
let d = self.fold_field_expression(d);
self.panic_buffer.push((
BooleanExpression::Not(box BooleanExpression::FieldEq(EqExpression::new(
d.clone(),
T::zero().into(),
))),
self.panic_buffer.push(ZirStatement::Assertion(
BooleanExpression::Not(box BooleanExpression::FieldEq(
box d.clone(),
box FieldElementExpression::Number(T::zero()),
)),
RuntimeError::DivisionByZero,
));
FieldElementExpression::Div(box n, box d)
}
FieldElementExpression::Conditional(
box condition,
box consequence,
box alternative,
) => {
let condition = self.fold_boolean_expression(condition);
let mut consequence_extractor = Self::default();
let consequence = consequence_extractor.fold_field_expression(consequence);
let mut alternative_extractor = Self::default();
let alternative = alternative_extractor.fold_field_expression(alternative);
self.panic_buffer.push(ZirStatement::IfElse(
condition.clone(),
consequence_extractor.panic_buffer.drain(..).collect(),
alternative_extractor.panic_buffer.drain(..).collect(),
));
FieldElementExpression::Conditional(box condition, box consequence, box alternative)
}
e => fold_field_expression(self, e),
}
}
@ -57,11 +100,11 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
UExpressionInner::Div(box n, box d) => {
let n = self.fold_uint_expression(n);
let d = self.fold_uint_expression(d);
self.panic_buffer.push((
BooleanExpression::Not(box BooleanExpression::UintEq(EqExpression::new(
d.clone(),
UExpressionInner::Value(0).annotate(b),
))),
self.panic_buffer.push(ZirStatement::Assertion(
BooleanExpression::Not(box BooleanExpression::UintEq(
box d.clone(),
box UExpressionInner::Value(0).annotate(b),
)),
RuntimeError::DivisionByZero,
));
UExpressionInner::Div(box n, box d)
@ -69,4 +112,62 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
e => fold_uint_expression_inner(self, b, e),
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
) -> BooleanExpression<'ast, T> {
match e {
e @ BooleanExpression::FieldLt(box FieldElementExpression::Number(_), _)
| e @ BooleanExpression::FieldLt(_, box FieldElementExpression::Number(_))
| e @ BooleanExpression::UintLt(
box UExpression {
inner: UExpressionInner::Value(_),
..
},
_,
)
| e @ BooleanExpression::UintLt(
_,
box UExpression {
inner: UExpressionInner::Value(_),
..
},
) => fold_boolean_expression(self, e),
BooleanExpression::FieldLt(box left, box right) => {
let left = self.fold_field_expression(left);
let right = self.fold_field_expression(right);
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2)
let offset = FieldElementExpression::Number(T::from(2).pow(safe_width));
let max = FieldElementExpression::Number(T::from(2).pow(safe_width + 1));
self.panic_buffer.push(ZirStatement::Assertion(
BooleanExpression::And(
box BooleanExpression::Not(box BooleanExpression::FieldEq(
box FieldElementExpression::Sub(box left.clone(), box right.clone()),
box offset.clone(),
)),
box BooleanExpression::FieldLt(
box FieldElementExpression::Add(
box offset,
box FieldElementExpression::Sub(
box right.clone(),
box left.clone(),
),
),
box max,
),
),
RuntimeError::IncompleteDynamicRange,
));
BooleanExpression::FieldLt(box left, box right)
}
e => fold_boolean_expression(self, e),
}
}
}

View file

@ -554,9 +554,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let e_str = e.to_string();
let expr = self.fold_boolean_expression(e)?;
match expr {
BooleanExpression::Value(v) if !v => {
BooleanExpression::Value(false) => {
Err(Error::AssertionFailed(format!("{}: ({})", ty, e_str)))
}
BooleanExpression::Value(true) => Ok(vec![]),
_ => Ok(vec![TypedStatement::Assertion(expr, ty)]),
}
}

View file

@ -114,24 +114,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
BooleanExpression::UintLe(box left, box right)
}
BooleanExpression::UintGt(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left = force_reduce(left);
let right = force_reduce(right);
BooleanExpression::UintGt(box left, box right)
}
BooleanExpression::UintGe(box left, box right) => {
let left = self.fold_uint_expression(left);
let right = self.fold_uint_expression(right);
let left = force_reduce(left);
let right = force_reduce(right);
BooleanExpression::UintGe(box left, box right)
}
e => fold_boolean_expression(self, e),
}
}

View file

@ -317,34 +317,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)),
}
}
BooleanExpression::FieldGe(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
self.fold_field_expression(e2)?,
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
Ok(BooleanExpression::Value(n1 >= n2))
}
(e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)),
}
}
BooleanExpression::FieldGt(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
self.fold_field_expression(e2)?,
) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
Ok(BooleanExpression::Value(n1 > n2))
}
(_, FieldElementExpression::Number(c)) if c == T::max_value() => {
Ok(BooleanExpression::Value(false))
}
(FieldElementExpression::Number(c), _) if c == T::zero() => {
Ok(BooleanExpression::Value(false))
}
(e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)),
}
}
BooleanExpression::FieldEq(box e1, box e2) => {
match (
self.fold_field_expression(e1)?,
@ -384,28 +356,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
_ => Ok(BooleanExpression::UintLe(box e1, box e2)),
}
}
BooleanExpression::UintGe(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(BooleanExpression::Value(v1 >= v2))
}
_ => Ok(BooleanExpression::UintGe(box e1, box e2)),
}
}
BooleanExpression::UintGt(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(BooleanExpression::Value(v1 > v2))
}
_ => Ok(BooleanExpression::UintGt(box e1, box e2)),
}
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
@ -1019,85 +969,6 @@ mod tests {
);
}
#[test]
fn field_le() {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldLe(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldLe(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn field_ge() {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn field_gt() {
let mut propagator = ZirPropagator::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(2)),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(3)),
box FieldElementExpression::Number(Bn128Field::from(3)),
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Number(Bn128Field::from(0)),
box FieldElementExpression::Identifier("a".into()),
)),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
box FieldElementExpression::Identifier("a".into()),
box FieldElementExpression::Number(Bn128Field::max_value()),
)),
Ok(BooleanExpression::Value(false))
);
}
#[test]
fn field_eq() {
let mut propagator = ZirPropagator::default();
@ -1140,69 +1011,6 @@ mod tests {
);
}
#[test]
fn uint_le() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintLe(
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintLe(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn uint_ge() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGe(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGe(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
}
#[test]
fn uint_gt() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGt(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
propagator.fold_boolean_expression(BooleanExpression::UintGt(
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
)),
Ok(BooleanExpression::Value(false))
);
}
#[test]
fn uint_eq() {
let mut propagator = ZirPropagator::<Bn128Field>::default();

View file

@ -0,0 +1,65 @@
{
"entry_point": "./tests/tests/div.zok",
"max_constraint_count": 3,
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"tests": [
{
"input": {
"values": ["0", "0"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "4",
"right": "2",
"error": "Inverse"
}
}
}
},
{
"input": {
"values": ["1", "0"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"left": "4",
"right": "2",
"error": "Inverse"
}
}
}
},
{
"input": {
"values": ["0", "1"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["2", "2"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["4", "2"]
},
"output": {
"Ok": {
"value": "2"
}
}
}
]
}

View file

@ -0,0 +1,3 @@
def main(field x, field y) -> field {
return x / y;
}