extract panics just before flattening, simplify zir, remove redundant checks in code generation
This commit is contained in:
parent
8caa6b4720
commit
3fbf63d335
18 changed files with 393 additions and 541 deletions
|
@ -38,6 +38,7 @@ impl From<crate::zir::RuntimeError> for RuntimeError {
|
|||
crate::zir::RuntimeError::SourceAssertion(s) => RuntimeError::SourceAssertion(s),
|
||||
crate::zir::RuntimeError::SelectRangeCheck => RuntimeError::SelectRangeCheck,
|
||||
crate::zir::RuntimeError::DivisionByZero => RuntimeError::Inverse,
|
||||
crate::zir::RuntimeError::IncompleteDynamicRange => RuntimeError::LtFinalSum,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -188,15 +188,6 @@ impl<T: Field> fmt::Display for Prog<T> {
|
|||
for s in &self.statements {
|
||||
writeln!(f, "\t{}", s)?;
|
||||
}
|
||||
writeln!(
|
||||
f,
|
||||
"\treturn {}",
|
||||
(0..self.return_count)
|
||||
.map(Variable::public)
|
||||
.map(|e| format!("{}", e))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)?;
|
||||
|
||||
writeln!(f, "\treturn {}", returns)?;
|
||||
writeln!(f, "}}")
|
||||
|
|
|
@ -212,41 +212,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
let e2 = f.fold_boolean_expression(e2);
|
||||
|
|
|
@ -8,7 +8,7 @@ mod uint;
|
|||
mod variable;
|
||||
|
||||
pub use self::parameter::Parameter;
|
||||
pub use self::types::Type;
|
||||
pub use self::types::{Type, UBitwidth};
|
||||
pub use self::variable::Variable;
|
||||
use crate::common::{FlatEmbed, FormatString};
|
||||
use crate::typed::ConcreteType;
|
||||
|
@ -23,7 +23,7 @@ pub use self::folder::Folder;
|
|||
pub use self::identifier::{Identifier, SourceIdentifier};
|
||||
|
||||
/// A typed program as a collection of modules, one of them being the main
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
#[derive(PartialEq, Eq, Debug, Clone)]
|
||||
pub struct ZirProgram<'ast, T> {
|
||||
pub main: ZirFunction<'ast, T>,
|
||||
}
|
||||
|
@ -93,6 +93,7 @@ pub enum RuntimeError {
|
|||
SourceAssertion(String),
|
||||
SelectRangeCheck,
|
||||
DivisionByZero,
|
||||
IncompleteDynamicRange,
|
||||
}
|
||||
|
||||
impl fmt::Display for RuntimeError {
|
||||
|
@ -101,6 +102,7 @@ impl fmt::Display for RuntimeError {
|
|||
RuntimeError::SourceAssertion(message) => write!(f, "{}", message),
|
||||
RuntimeError::SelectRangeCheck => write!(f, "Range check on array access"),
|
||||
RuntimeError::DivisionByZero => write!(f, "Division by zero"),
|
||||
RuntimeError::IncompleteDynamicRange => write!(f, "Dynamic comparison is incomplete"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -140,14 +142,15 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
match self {
|
||||
ZirStatement::Return(ref exprs) => {
|
||||
write!(f, "return ")?;
|
||||
for (i, expr) in exprs.iter().enumerate() {
|
||||
write!(f, "{}", expr)?;
|
||||
if i < exprs.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ";")
|
||||
write!(
|
||||
f,
|
||||
"return {};",
|
||||
exprs
|
||||
.iter()
|
||||
.map(|e| e.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
}
|
||||
ZirStatement::Definition(ref lhs, ref rhs) => {
|
||||
write!(f, "{} = {};", lhs, rhs)
|
||||
|
@ -183,7 +186,7 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
}
|
||||
ZirStatement::Log(ref l, ref expressions) => write!(
|
||||
f,
|
||||
"log(\"{}\"), {})",
|
||||
"log(\"{}\"), {});",
|
||||
l,
|
||||
expressions
|
||||
.iter()
|
||||
|
@ -335,22 +338,12 @@ pub enum BooleanExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGe(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldGt(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
FieldEq(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
UintLt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
BoolEq(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
|
@ -489,7 +482,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
|||
UExpressionInner::Or(ref lhs, ref rhs) => write!(f, "({} | {})", lhs, rhs),
|
||||
UExpressionInner::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||
UExpressionInner::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!{}", e),
|
||||
UExpressionInner::Not(ref e) => write!(f, "!({})", e),
|
||||
UExpressionInner::Conditional(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
|
@ -515,20 +508,16 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
.join(", "),
|
||||
i
|
||||
),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::FieldGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
BooleanExpression::FieldLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
|
||||
BooleanExpression::UintLt(ref lhs, ref rhs) => write!(f, "({} < {})", lhs, rhs),
|
||||
BooleanExpression::FieldLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
|
||||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "({} <= {})", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "({} == {})", lhs, rhs),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "({} || {})", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "({} && {})", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!({})", exp),
|
||||
BooleanExpression::Conditional(ref condition, ref consequent, ref alternative) => {
|
||||
write!(
|
||||
f,
|
||||
|
|
|
@ -246,41 +246,21 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldLe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintLe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintLe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGt(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintGe(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
|
|
|
@ -9,5 +9,5 @@ def main(field a) {
|
|||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
field maxvalue = a + (2**(pbits - 2) - 1);
|
||||
bool c = a < maxvalue + 1;
|
||||
return
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -9,5 +9,5 @@ def main(field a) {
|
|||
// we added a = 0 to prevent the condition to be evaluated at compile time
|
||||
field maxvalue = a + (2**(pbits - 2) - 1);
|
||||
bool c = maxvalue + 1 < a;
|
||||
return
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -774,66 +774,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
let sub_width = bit_width + 1;
|
||||
|
||||
// define variables for the bits
|
||||
let shifted_sub_bits_be: Vec<Variable> =
|
||||
(0..sub_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
shifted_sub_bits_be.clone(),
|
||||
Solver::bits(sub_width),
|
||||
vec![shifted_sub.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for bit in shifted_sub_bits_be.iter() {
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(*bit),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Identifier(*bit),
|
||||
),
|
||||
RuntimeError::LtFinalBitness,
|
||||
));
|
||||
}
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for (i, bit) in shifted_sub_bits_be.iter().take(sub_width).enumerate() {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(*bit),
|
||||
box FlatExpression::Number(T::from(2).pow(sub_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
shifted_sub,
|
||||
expr,
|
||||
RuntimeError::LtFinalSum,
|
||||
));
|
||||
|
||||
// to make this check symetric, we ban the value `a - b == -2**N`, as the value `a - b == 2**N` is already banned
|
||||
let fail = self.eq_check(
|
||||
let shifted_sub_bits_be = self.get_bits_unchecked(
|
||||
&FlatUExpression::with_field(shifted_sub),
|
||||
sub_width,
|
||||
sub_width,
|
||||
statements_flattened,
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
FlatExpression::Number(T::from(2).pow(bit_width)),
|
||||
RuntimeError::LtFinalSum,
|
||||
);
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
fail,
|
||||
FlatExpression::Number(T::from(0)),
|
||||
RuntimeError::LtSymetric,
|
||||
));
|
||||
|
||||
FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(shifted_sub_bits_be[0]),
|
||||
box shifted_sub_bits_be[0].clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -949,14 +900,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::FieldGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::FieldGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::FieldLe(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintLt(box lhs, box rhs) => {
|
||||
let bit_width = lhs.bitwidth.to_usize();
|
||||
assert!(lhs.metadata.as_ref().unwrap().should_reduce.to_bool());
|
||||
|
@ -987,14 +930,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
);
|
||||
FlatExpression::Add(box eq, box lt)
|
||||
}
|
||||
BooleanExpression::UintGt(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLt(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::UintGe(lhs, rhs) => self.flatten_boolean_expression(
|
||||
statements_flattened,
|
||||
BooleanExpression::UintLe(rhs, lhs),
|
||||
),
|
||||
BooleanExpression::Or(box lhs, box rhs) => {
|
||||
let x = self.flatten_boolean_expression(statements_flattened, lhs);
|
||||
let y = self.flatten_boolean_expression(statements_flattened, rhs);
|
||||
|
@ -1368,23 +1303,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Identifier(id)
|
||||
};
|
||||
|
||||
// first check that the d is not 0 by giving its inverse
|
||||
let invd = self.use_sym();
|
||||
|
||||
// # invd = 1/d
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invd],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), d.clone()],
|
||||
)));
|
||||
|
||||
// assert(invd * d == 1)
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invd.into(), box d.clone()),
|
||||
RuntimeError::Inverse,
|
||||
));
|
||||
|
||||
// now introduce the quotient and remainder
|
||||
let q = self.use_sym();
|
||||
let r = self.use_sym();
|
||||
|
@ -2184,23 +2102,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
id.into()
|
||||
};
|
||||
|
||||
let invb = self.use_sym();
|
||||
// `right` is assumed to already be non-zero so this is an unchecked division
|
||||
// TODO: we could save one constraint here by reusing the inverse of `right` computed earlier
|
||||
|
||||
let inverse = self.use_sym();
|
||||
|
||||
// # invb = 1/b
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invb],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), new_right.clone()],
|
||||
)));
|
||||
|
||||
// assert(invb * b == 1)
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invb.into(), box new_right.clone()),
|
||||
RuntimeError::Inverse,
|
||||
));
|
||||
|
||||
// # c = a/b
|
||||
statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![inverse],
|
||||
|
@ -2428,8 +2334,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
error.into(),
|
||||
)
|
||||
}
|
||||
BooleanExpression::FieldLt(box lhs, box rhs)
|
||||
| BooleanExpression::FieldGt(box rhs, box lhs) => {
|
||||
BooleanExpression::FieldLt(box lhs, box rhs) => {
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
|
@ -2459,8 +2364,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldLe(box lhs, box rhs)
|
||||
| BooleanExpression::FieldGe(box rhs, box lhs) => {
|
||||
BooleanExpression::FieldLe(box lhs, box rhs) => {
|
||||
let lhs = self.flatten_field_expression(statements_flattened, lhs);
|
||||
let rhs = self.flatten_field_expression(statements_flattened, rhs);
|
||||
|
||||
|
@ -2490,6 +2394,40 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
}
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintLe(box lhs, box rhs) => {
|
||||
let lhs = self
|
||||
.flatten_uint_expression(statements_flattened, lhs)
|
||||
.get_field_unchecked();
|
||||
let rhs = self
|
||||
.flatten_uint_expression(statements_flattened, rhs)
|
||||
.get_field_unchecked();
|
||||
|
||||
match (lhs, rhs) {
|
||||
(e, FlatExpression::Number(c)) => self.enforce_constant_le_check(
|
||||
statements_flattened,
|
||||
e,
|
||||
c,
|
||||
error.into(),
|
||||
),
|
||||
// c <= e <=> p - 1 - e <= p - 1 - c
|
||||
(FlatExpression::Number(c), e) => self.enforce_constant_le_check(
|
||||
statements_flattened,
|
||||
FlatExpression::Sub(box T::max_value().into(), box e),
|
||||
T::max_value() - c,
|
||||
error.into(),
|
||||
),
|
||||
(lhs, rhs) => {
|
||||
let bit_width = T::get_required_bits();
|
||||
let safe_width = bit_width - 2; // dynamic comparison is not complete
|
||||
let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
e,
|
||||
FlatExpression::Number(T::one()),
|
||||
error.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintEq(box lhs, box rhs) => {
|
||||
let lhs = self
|
||||
.flatten_uint_expression(statements_flattened, lhs)
|
||||
|
@ -2516,6 +2454,74 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
error.into(),
|
||||
)
|
||||
}
|
||||
// `!(x == 0)` can be asserted by giving the inverse of `x`
|
||||
BooleanExpression::Not(box BooleanExpression::UintEq(
|
||||
box UExpression {
|
||||
inner: UExpressionInner::Value(0),
|
||||
..
|
||||
},
|
||||
box x,
|
||||
))
|
||||
| BooleanExpression::Not(box BooleanExpression::UintEq(
|
||||
box x,
|
||||
box UExpression {
|
||||
inner: UExpressionInner::Value(0),
|
||||
..
|
||||
},
|
||||
)) => {
|
||||
let x = self
|
||||
.flatten_uint_expression(statements_flattened, x)
|
||||
.get_field_unchecked();
|
||||
|
||||
// check that `x` is not 0 by giving its inverse
|
||||
let invx = self.use_sym();
|
||||
|
||||
// # invx = 1/x
|
||||
statements_flattened.push_back(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
vec![invx],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), x.clone()],
|
||||
),
|
||||
));
|
||||
|
||||
// assert(invx * x == 1)
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invx.into(), box x.clone()),
|
||||
RuntimeError::Inverse,
|
||||
));
|
||||
}
|
||||
// `!(x == 0)` can be asserted by giving the inverse of `x`
|
||||
BooleanExpression::Not(box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(zero),
|
||||
box x,
|
||||
))
|
||||
| BooleanExpression::Not(box BooleanExpression::FieldEq(
|
||||
box x,
|
||||
box FieldElementExpression::Number(zero),
|
||||
)) if zero == T::from(0) => {
|
||||
let x = self.flatten_field_expression(statements_flattened, x);
|
||||
|
||||
// check that `x` is not 0 by giving its inverse
|
||||
let invx = self.use_sym();
|
||||
|
||||
// # invx = 1/x
|
||||
statements_flattened.push_back(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
vec![invx],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(T::one()), x.clone()],
|
||||
),
|
||||
));
|
||||
|
||||
// assert(invx * x == 1)
|
||||
statements_flattened.push_back(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::one()),
|
||||
FlatExpression::Mult(box invx.into(), box x.clone()),
|
||||
RuntimeError::Inverse,
|
||||
));
|
||||
}
|
||||
e => {
|
||||
// naive approach: flatten the boolean to a single field element and constrain it to 1
|
||||
let e = self.flatten_boolean_expression(statements_flattened, e);
|
||||
|
@ -3545,24 +3551,6 @@ mod tests {
|
|||
flattener.flatten_field_expression(&mut FlatStatements::new(), expression);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn geq_leq() {
|
||||
let config = CompileConfig::default();
|
||||
let mut flattener = Flattener::new(config);
|
||||
let expression_le = BooleanExpression::FieldLe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_le);
|
||||
|
||||
let mut flattener = Flattener::new(config);
|
||||
let expression_ge = BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(32)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
);
|
||||
flattener.flatten_boolean_expression(&mut FlatStatements::new(), expression_ge);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bool_and() {
|
||||
let config = CompileConfig::default();
|
||||
|
@ -3619,18 +3607,14 @@ mod tests {
|
|||
// define new wires for members of Div
|
||||
let five = Variable::new(1);
|
||||
let b0 = Variable::new(2);
|
||||
// Define inverse of denominator to prevent div by 0
|
||||
let invb0 = Variable::new(3);
|
||||
// Define inverse
|
||||
let sym_0 = Variable::new(4);
|
||||
let sym_0 = Variable::new(3);
|
||||
// Define result, which is first member to next Div
|
||||
let sym_1 = Variable::new(5);
|
||||
let sym_1 = Variable::new(4);
|
||||
// Define second member
|
||||
let b1 = Variable::new(6);
|
||||
// Define inverse of denominator to prevent div by 0
|
||||
let invb1 = Variable::new(7);
|
||||
let b1 = Variable::new(5);
|
||||
// Define inverse
|
||||
let sym_2 = Variable::new(8);
|
||||
let sym_2 = Variable::new(6);
|
||||
|
||||
assert_eq!(
|
||||
statements_flattened,
|
||||
|
@ -3639,17 +3623,6 @@ mod tests {
|
|||
// inputs to first div (5/b)
|
||||
FlatStatement::Definition(five, FlatExpression::Number(Bn128Field::from(5))),
|
||||
FlatStatement::Definition(b0, b.into()),
|
||||
// check div by 0
|
||||
FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invb0],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(Bn128Field::from(1)), b0.into()]
|
||||
)),
|
||||
FlatStatement::Condition(
|
||||
FlatExpression::Number(Bn128Field::from(1)),
|
||||
FlatExpression::Mult(box invb0.into(), box b0.into()),
|
||||
RuntimeError::Inverse,
|
||||
),
|
||||
// execute div
|
||||
FlatStatement::Directive(FlatDirective::new(
|
||||
vec![sym_0],
|
||||
|
@ -3664,17 +3637,6 @@ mod tests {
|
|||
// inputs to second div (res/b)
|
||||
FlatStatement::Definition(sym_1, sym_0.into()),
|
||||
FlatStatement::Definition(b1, b.into()),
|
||||
// check div by 0
|
||||
FlatStatement::Directive(FlatDirective::new(
|
||||
vec![invb1],
|
||||
Solver::Div,
|
||||
vec![FlatExpression::Number(Bn128Field::from(1)), b1.into()]
|
||||
)),
|
||||
FlatStatement::Condition(
|
||||
FlatExpression::Number(Bn128Field::from(1)),
|
||||
FlatExpression::Mult(box invb1.into(), box b1.into()),
|
||||
RuntimeError::Inverse
|
||||
),
|
||||
// execute div
|
||||
FlatStatement::Directive(FlatDirective::new(
|
||||
vec![sym_2],
|
||||
|
|
|
@ -67,7 +67,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ConditionRedefiner<'ast, T> {
|
|||
let condition_id = Identifier::from(CoreIdentifier::Condition(self.index));
|
||||
self.buffer.push(TypedStatement::definition(
|
||||
Variable::immutable(condition_id.clone(), Type::Boolean).into(),
|
||||
TypedExpression::from(condition).into(),
|
||||
TypedExpression::from(condition),
|
||||
));
|
||||
self.index += 1;
|
||||
BooleanExpression::Identifier(condition_id)
|
||||
|
|
|
@ -1,24 +1,20 @@
|
|||
use std::collections::HashSet;
|
||||
use zokrates_ast::typed::{
|
||||
folder::*, BlockExpression, Identifier, TypedAssignee, TypedFunction, TypedProgram,
|
||||
TypedStatement,
|
||||
};
|
||||
use zokrates_ast::zir::{folder::*, Identifier, ZirFunction, ZirProgram, ZirStatement};
|
||||
use zokrates_field::Field;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DeadCodeEliminator<'ast> {
|
||||
used: HashSet<Identifier<'ast>>,
|
||||
in_block: usize,
|
||||
}
|
||||
|
||||
impl<'ast> DeadCodeEliminator<'ast> {
|
||||
pub fn eliminate<T: Field>(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
pub fn eliminate<T: Field>(p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
|
||||
Self::default().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
fn fold_function(&mut self, f: ZirFunction<'ast, T>) -> ZirFunction<'ast, T> {
|
||||
// iterate on the statements starting from the end, as we want to see usage before definition
|
||||
let mut statements: Vec<_> = f
|
||||
.statements
|
||||
|
@ -27,51 +23,44 @@ impl<'ast, T: Field> Folder<'ast, T> for DeadCodeEliminator<'ast> {
|
|||
.flat_map(|s| self.fold_statement(s))
|
||||
.collect();
|
||||
statements.reverse();
|
||||
TypedFunction { statements, ..f }
|
||||
ZirFunction { statements, ..f }
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
|
||||
match s {
|
||||
TypedStatement::Definition(a, e) => match a {
|
||||
TypedAssignee::Identifier(ref id) => {
|
||||
// if the lhs is used later in the program and we're in a block
|
||||
if self.used.remove(&id.id) {
|
||||
// include this statement
|
||||
fold_statement(self, TypedStatement::Definition(a, e))
|
||||
} else {
|
||||
// otherwise remove it
|
||||
vec![]
|
||||
}
|
||||
ZirStatement::Definition(v, e) => {
|
||||
// if the lhs is used later in the program and we're in a block
|
||||
if self.used.remove(&v.id) {
|
||||
// include this statement
|
||||
fold_statement(self, ZirStatement::Definition(v, e))
|
||||
} else {
|
||||
// otherwise remove it
|
||||
vec![]
|
||||
}
|
||||
_ => fold_statement(self, TypedStatement::Definition(a, e)),
|
||||
},
|
||||
TypedStatement::For(..) => {
|
||||
unreachable!("for loops should be removed before dead code elimination is run")
|
||||
}
|
||||
ZirStatement::IfElse(condition, consequence, alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition);
|
||||
|
||||
let mut consequence: Vec<_> = consequence
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|e| self.fold_statement(e))
|
||||
.collect();
|
||||
consequence.reverse();
|
||||
|
||||
let mut alternative: Vec<_> = alternative
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|e| self.fold_statement(e))
|
||||
.collect();
|
||||
alternative.reverse();
|
||||
|
||||
vec![ZirStatement::IfElse(condition, consequence, alternative)]
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_block_expression<E: Fold<'ast, T>>(
|
||||
&mut self,
|
||||
block: BlockExpression<'ast, T, E>,
|
||||
) -> BlockExpression<'ast, T, E> {
|
||||
self.in_block += 1;
|
||||
|
||||
let value = box block.value.fold(self);
|
||||
let mut statements: Vec<_> = block
|
||||
.statements
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|s| self.fold_statement(s))
|
||||
.collect();
|
||||
statements.reverse();
|
||||
|
||||
let block = BlockExpression { value, statements };
|
||||
self.in_block -= 1;
|
||||
block
|
||||
}
|
||||
|
||||
fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> {
|
||||
self.used.insert(n.clone());
|
||||
n
|
||||
|
|
|
@ -991,12 +991,12 @@ fn fold_boolean_expression<'ast, T: Field>(
|
|||
typed::BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::FieldGt(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLt(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::FieldGe(box e1, box e2)
|
||||
zir::BooleanExpression::FieldLe(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::UintLt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
|
@ -1011,12 +1011,12 @@ fn fold_boolean_expression<'ast, T: Field>(
|
|||
typed::BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_uint_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::UintGt(box e1, box e2)
|
||||
zir::BooleanExpression::UintLt(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_uint_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::UintGe(box e1, box e2)
|
||||
zir::BooleanExpression::UintLe(box e2, box e1)
|
||||
}
|
||||
typed::BooleanExpression::Or(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(statements_buffer, e1);
|
||||
|
|
|
@ -166,14 +166,6 @@ pub fn analyse<'ast, T: Field>(
|
|||
let r = ConditionRedefiner::redefine(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Extract panics");
|
||||
let r = PanicExtractor::extract(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
log::debug!("Static analyser: Remove dead code");
|
||||
let r = DeadCodeEliminator::eliminate(r);
|
||||
log::trace!("\n{}", r);
|
||||
|
||||
// convert to zir, removing complex types
|
||||
log::debug!("Static analyser: Convert to zir");
|
||||
let zir = Flattener::flatten(r);
|
||||
|
@ -184,6 +176,14 @@ pub fn analyse<'ast, T: Field>(
|
|||
let zir = ZirPropagator::propagate(zir).map_err(Error::from)?;
|
||||
log::trace!("\n{}", zir);
|
||||
|
||||
log::debug!("Static analyser: Extract panics");
|
||||
let zir = PanicExtractor::extract(zir);
|
||||
log::trace!("\n{}", zir);
|
||||
|
||||
log::debug!("Static analyser: Remove dead code");
|
||||
let zir = DeadCodeEliminator::eliminate(zir);
|
||||
log::trace!("\n{}", zir);
|
||||
|
||||
// optimize uint expressions
|
||||
log::debug!("Static analyser: Optimize uints");
|
||||
let zir = UintOptimizer::optimize(zir);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use zokrates_ast::typed::{
|
||||
folder::*, BooleanExpression, EqExpression, FieldElementExpression, RuntimeError, TypedProgram,
|
||||
TypedStatement, UBitwidth, UExpressionInner,
|
||||
use zokrates_ast::zir::{
|
||||
folder::*, BooleanExpression, FieldElementExpression, RuntimeError, UBitwidth, UExpression,
|
||||
UExpressionInner, ZirProgram, ZirStatement,
|
||||
};
|
||||
use zokrates_field::Field;
|
||||
|
||||
|
@ -8,23 +8,47 @@ use zokrates_field::Field;
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct PanicExtractor<'ast, T> {
|
||||
panic_buffer: Vec<(BooleanExpression<'ast, T>, RuntimeError)>,
|
||||
panic_buffer: Vec<ZirStatement<'ast, T>>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> PanicExtractor<'ast, T> {
|
||||
pub fn extract(p: TypedProgram<'ast, T>) -> TypedProgram<'ast, T> {
|
||||
pub fn extract(p: ZirProgram<'ast, T>) -> ZirProgram<'ast, T> {
|
||||
Self::default().fold_program(p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
|
||||
fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec<TypedStatement<'ast, T>> {
|
||||
let s = fold_statement(self, s);
|
||||
self.panic_buffer
|
||||
.drain(..)
|
||||
.map(|(b, e)| TypedStatement::Assertion(b, e))
|
||||
.chain(s)
|
||||
.collect()
|
||||
fn fold_statement(&mut self, s: ZirStatement<'ast, T>) -> Vec<ZirStatement<'ast, T>> {
|
||||
match s {
|
||||
ZirStatement::IfElse(condition, consequence, alternative) => {
|
||||
let condition = self.fold_boolean_expression(condition);
|
||||
let mut consequence_extractor = Self::default();
|
||||
let consequence = consequence
|
||||
.into_iter()
|
||||
.flat_map(|s| consequence_extractor.fold_statement(s))
|
||||
.collect();
|
||||
assert!(consequence_extractor.panic_buffer.is_empty());
|
||||
let mut alternative_extractor = Self::default();
|
||||
let alternative = alternative
|
||||
.into_iter()
|
||||
.flat_map(|s| alternative_extractor.fold_statement(s))
|
||||
.collect();
|
||||
assert!(alternative_extractor.panic_buffer.is_empty());
|
||||
|
||||
self.panic_buffer
|
||||
.drain(..)
|
||||
.chain(std::iter::once(ZirStatement::IfElse(
|
||||
condition,
|
||||
consequence,
|
||||
alternative,
|
||||
)))
|
||||
.collect()
|
||||
}
|
||||
s => {
|
||||
let s = fold_statement(self, s);
|
||||
self.panic_buffer.drain(..).chain(s).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
|
@ -35,15 +59,34 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
|
|||
FieldElementExpression::Div(box n, box d) => {
|
||||
let n = self.fold_field_expression(n);
|
||||
let d = self.fold_field_expression(d);
|
||||
self.panic_buffer.push((
|
||||
BooleanExpression::Not(box BooleanExpression::FieldEq(EqExpression::new(
|
||||
d.clone(),
|
||||
T::zero().into(),
|
||||
))),
|
||||
self.panic_buffer.push(ZirStatement::Assertion(
|
||||
BooleanExpression::Not(box BooleanExpression::FieldEq(
|
||||
box d.clone(),
|
||||
box FieldElementExpression::Number(T::zero()),
|
||||
)),
|
||||
RuntimeError::DivisionByZero,
|
||||
));
|
||||
FieldElementExpression::Div(box n, box d)
|
||||
}
|
||||
FieldElementExpression::Conditional(
|
||||
box condition,
|
||||
box consequence,
|
||||
box alternative,
|
||||
) => {
|
||||
let condition = self.fold_boolean_expression(condition);
|
||||
let mut consequence_extractor = Self::default();
|
||||
let consequence = consequence_extractor.fold_field_expression(consequence);
|
||||
let mut alternative_extractor = Self::default();
|
||||
let alternative = alternative_extractor.fold_field_expression(alternative);
|
||||
|
||||
self.panic_buffer.push(ZirStatement::IfElse(
|
||||
condition.clone(),
|
||||
consequence_extractor.panic_buffer.drain(..).collect(),
|
||||
alternative_extractor.panic_buffer.drain(..).collect(),
|
||||
));
|
||||
|
||||
FieldElementExpression::Conditional(box condition, box consequence, box alternative)
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
@ -57,11 +100,11 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
|
|||
UExpressionInner::Div(box n, box d) => {
|
||||
let n = self.fold_uint_expression(n);
|
||||
let d = self.fold_uint_expression(d);
|
||||
self.panic_buffer.push((
|
||||
BooleanExpression::Not(box BooleanExpression::UintEq(EqExpression::new(
|
||||
d.clone(),
|
||||
UExpressionInner::Value(0).annotate(b),
|
||||
))),
|
||||
self.panic_buffer.push(ZirStatement::Assertion(
|
||||
BooleanExpression::Not(box BooleanExpression::UintEq(
|
||||
box d.clone(),
|
||||
box UExpressionInner::Value(0).annotate(b),
|
||||
)),
|
||||
RuntimeError::DivisionByZero,
|
||||
));
|
||||
UExpressionInner::Div(box n, box d)
|
||||
|
@ -69,4 +112,62 @@ impl<'ast, T: Field> Folder<'ast, T> for PanicExtractor<'ast, T> {
|
|||
e => fold_uint_expression_inner(self, b, e),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
) -> BooleanExpression<'ast, T> {
|
||||
match e {
|
||||
e @ BooleanExpression::FieldLt(box FieldElementExpression::Number(_), _)
|
||||
| e @ BooleanExpression::FieldLt(_, box FieldElementExpression::Number(_))
|
||||
| e @ BooleanExpression::UintLt(
|
||||
box UExpression {
|
||||
inner: UExpressionInner::Value(_),
|
||||
..
|
||||
},
|
||||
_,
|
||||
)
|
||||
| e @ BooleanExpression::UintLt(
|
||||
_,
|
||||
box UExpression {
|
||||
inner: UExpressionInner::Value(_),
|
||||
..
|
||||
},
|
||||
) => fold_boolean_expression(self, e),
|
||||
BooleanExpression::FieldLt(box left, box right) => {
|
||||
let left = self.fold_field_expression(left);
|
||||
let right = self.fold_field_expression(right);
|
||||
|
||||
let bit_width = T::get_required_bits();
|
||||
|
||||
let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2)
|
||||
|
||||
let offset = FieldElementExpression::Number(T::from(2).pow(safe_width));
|
||||
let max = FieldElementExpression::Number(T::from(2).pow(safe_width + 1));
|
||||
|
||||
self.panic_buffer.push(ZirStatement::Assertion(
|
||||
BooleanExpression::And(
|
||||
box BooleanExpression::Not(box BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Sub(box left.clone(), box right.clone()),
|
||||
box offset.clone(),
|
||||
)),
|
||||
box BooleanExpression::FieldLt(
|
||||
box FieldElementExpression::Add(
|
||||
box offset,
|
||||
box FieldElementExpression::Sub(
|
||||
box right.clone(),
|
||||
box left.clone(),
|
||||
),
|
||||
),
|
||||
box max,
|
||||
),
|
||||
),
|
||||
RuntimeError::IncompleteDynamicRange,
|
||||
));
|
||||
|
||||
BooleanExpression::FieldLt(box left, box right)
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -554,9 +554,10 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
let e_str = e.to_string();
|
||||
let expr = self.fold_boolean_expression(e)?;
|
||||
match expr {
|
||||
BooleanExpression::Value(v) if !v => {
|
||||
BooleanExpression::Value(false) => {
|
||||
Err(Error::AssertionFailed(format!("{}: ({})", ty, e_str)))
|
||||
}
|
||||
BooleanExpression::Value(true) => Ok(vec![]),
|
||||
_ => Ok(vec![TypedStatement::Assertion(expr, ty)]),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,24 +114,6 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
|
||||
BooleanExpression::UintLe(box left, box right)
|
||||
}
|
||||
BooleanExpression::UintGt(box left, box right) => {
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left = force_reduce(left);
|
||||
let right = force_reduce(right);
|
||||
|
||||
BooleanExpression::UintGt(box left, box right)
|
||||
}
|
||||
BooleanExpression::UintGe(box left, box right) => {
|
||||
let left = self.fold_uint_expression(left);
|
||||
let right = self.fold_uint_expression(right);
|
||||
|
||||
let left = force_reduce(left);
|
||||
let right = force_reduce(right);
|
||||
|
||||
BooleanExpression::UintGe(box left, box right)
|
||||
}
|
||||
e => fold_boolean_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -317,34 +317,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
(e1, e2) => Ok(BooleanExpression::FieldLe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGe(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
self.fold_field_expression(e2)?,
|
||||
) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 >= n2))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::FieldGe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldGt(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
self.fold_field_expression(e2)?,
|
||||
) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 > n2))
|
||||
}
|
||||
(_, FieldElementExpression::Number(c)) if c == T::max_value() => {
|
||||
Ok(BooleanExpression::Value(false))
|
||||
}
|
||||
(FieldElementExpression::Number(c), _) if c == T::zero() => {
|
||||
Ok(BooleanExpression::Value(false))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::FieldGt(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
match (
|
||||
self.fold_field_expression(e1)?,
|
||||
|
@ -384,28 +356,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
_ => Ok(BooleanExpression::UintLe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintGe(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
||||
match (e1.as_inner(), e2.as_inner()) {
|
||||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
Ok(BooleanExpression::Value(v1 >= v2))
|
||||
}
|
||||
_ => Ok(BooleanExpression::UintGe(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintGt(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
||||
match (e1.as_inner(), e2.as_inner()) {
|
||||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
Ok(BooleanExpression::Value(v1 > v2))
|
||||
}
|
||||
_ => Ok(BooleanExpression::UintGt(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
@ -1019,85 +969,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_le() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldLe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldLe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_ge() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGe(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_gt() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Number(Bn128Field::from(0)),
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::FieldGt(
|
||||
box FieldElementExpression::Identifier("a".into()),
|
||||
box FieldElementExpression::Number(Bn128Field::max_value()),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn field_eq() {
|
||||
let mut propagator = ZirPropagator::default();
|
||||
|
@ -1140,69 +1011,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_le() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintLe(
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintLe(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_ge() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGe(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGe(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_gt() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGt(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(2).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
propagator.fold_boolean_expression(BooleanExpression::UintGt(
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
box UExpressionInner::Value(3).annotate(UBitwidth::B32),
|
||||
)),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uint_eq() {
|
||||
let mut propagator = ZirPropagator::<Bn128Field>::default();
|
||||
|
|
65
zokrates_core_test/tests/tests/div.json
Normal file
65
zokrates_core_test/tests/tests/div.json
Normal file
|
@ -0,0 +1,65 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/div.zok",
|
||||
"max_constraint_count": 3,
|
||||
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Err": {
|
||||
"UnsatisfiedConstraint": {
|
||||
"left": "4",
|
||||
"right": "2",
|
||||
"error": "Inverse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Err": {
|
||||
"UnsatisfiedConstraint": {
|
||||
"left": "4",
|
||||
"right": "2",
|
||||
"error": "Inverse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2", "2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["4", "2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "2"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
3
zokrates_core_test/tests/tests/div.zok
Normal file
3
zokrates_core_test/tests/tests/div.zok
Normal file
|
@ -0,0 +1,3 @@
|
|||
def main(field x, field y) -> field {
|
||||
return x / y;
|
||||
}
|
Loading…
Reference in a new issue