detect division by zero in typed propagation
This commit is contained in:
parent
cd15f020e0
commit
791a9b66d5
1 changed files with 34 additions and 10 deletions
|
@ -33,6 +33,7 @@ pub enum Error {
|
||||||
InvalidValue(String),
|
InvalidValue(String),
|
||||||
OutOfBounds(u128, u128),
|
OutOfBounds(u128, u128),
|
||||||
VariableLength(String),
|
VariableLength(String),
|
||||||
|
DivisionByZero,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
|
@ -47,6 +48,9 @@ impl fmt::Display for Error {
|
||||||
index, size
|
index, size
|
||||||
),
|
),
|
||||||
Error::VariableLength(message) => write!(f, "{}", message),
|
Error::VariableLength(message) => write!(f, "{}", message),
|
||||||
|
Error::DivisionByZero => {
|
||||||
|
write!(f, "Division by zero detected during static analysis",)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -924,12 +928,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
||||||
let left = self.fold_field_expression(*e.left)?;
|
let left = self.fold_field_expression(*e.left)?;
|
||||||
let right = self.fold_field_expression(*e.right)?;
|
let right = self.fold_field_expression(*e.right)?;
|
||||||
|
|
||||||
Ok(match (left, right) {
|
match (left, right) {
|
||||||
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
|
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
|
||||||
FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value))
|
Err(Error::DivisionByZero)
|
||||||
|
}
|
||||||
|
(e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e),
|
||||||
|
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => Ok(
|
||||||
|
FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value)),
|
||||||
|
),
|
||||||
|
(e1, e2) => Ok(e1 / e2),
|
||||||
}
|
}
|
||||||
(e1, e2) => e1 / e2,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? {
|
FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? {
|
||||||
FieldElementExpression::Value(n) => {
|
FieldElementExpression::Value(n) => {
|
||||||
|
@ -1571,14 +1579,30 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn div() {
|
fn div() {
|
||||||
let e = FieldElementExpression::div(
|
let mut propagator = Propagator::default();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::div(
|
||||||
FieldElementExpression::value(Bn128Field::from(6)),
|
FieldElementExpression::value(Bn128Field::from(6)),
|
||||||
FieldElementExpression::value(Bn128Field::from(2)),
|
FieldElementExpression::value(Bn128Field::from(2)),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::value(Bn128Field::from(3)))
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Propagator::default().fold_field_expression(e),
|
propagator.fold_field_expression(FieldElementExpression::div(
|
||||||
Ok(FieldElementExpression::value(Bn128Field::from(3)))
|
FieldElementExpression::identifier("a".into()),
|
||||||
|
FieldElementExpression::value(Bn128Field::from(1)),
|
||||||
|
)),
|
||||||
|
Ok(FieldElementExpression::identifier("a".into()))
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
propagator.fold_field_expression(FieldElementExpression::div(
|
||||||
|
FieldElementExpression::value(Bn128Field::from(6)),
|
||||||
|
FieldElementExpression::value(Bn128Field::from(0)),
|
||||||
|
)),
|
||||||
|
Err(Error::DivisionByZero)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue