From 791a9b66d5417609188a543f13e155a76c00f06a Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 13 Jun 2023 00:36:57 +0200 Subject: [PATCH] detect division by zero in typed propagation --- zokrates_analysis/src/propagation.rs | 44 +++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 2e047c64..03c98c38 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -33,6 +33,7 @@ pub enum Error { InvalidValue(String), OutOfBounds(u128, u128), VariableLength(String), + DivisionByZero, } impl fmt::Display for Error { @@ -47,6 +48,9 @@ impl fmt::Display for Error { index, size ), Error::VariableLength(message) => write!(f, "{}", message), + Error::DivisionByZero => { + write!(f, "Division by zero detected during static analysis",) + } } } } @@ -924,12 +928,16 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { let left = self.fold_field_expression(*e.left)?; let right = self.fold_field_expression(*e.right)?; - Ok(match (left, right) { - (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => { - FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value)) + match (left, right) { + (_, FieldElementExpression::Value(n)) if n.value == T::from(0) => { + Err(Error::DivisionByZero) } - (e1, e2) => e1 / e2, - }) + (e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e), + (FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => Ok( + FieldElementExpression::Value(ValueExpression::new(n1.value / n2.value)), + ), + (e1, e2) => Ok(e1 / e2), + } } FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? { FieldElementExpression::Value(n) => { @@ -1571,14 +1579,30 @@ mod tests { #[test] fn div() { - let e = FieldElementExpression::div( - FieldElementExpression::value(Bn128Field::from(6)), - FieldElementExpression::value(Bn128Field::from(2)), + let mut propagator = Propagator::default(); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(2)), + )), + Ok(FieldElementExpression::value(Bn128Field::from(3))) ); assert_eq!( - Propagator::default().fold_field_expression(e), - Ok(FieldElementExpression::value(Bn128Field::from(3))) + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::identifier("a".into()), + FieldElementExpression::value(Bn128Field::from(1)), + )), + Ok(FieldElementExpression::identifier("a".into())) + ); + + assert_eq!( + propagator.fold_field_expression(FieldElementExpression::div( + FieldElementExpression::value(Bn128Field::from(6)), + FieldElementExpression::value(Bn128Field::from(0)), + )), + Err(Error::DivisionByZero) ); }