Merge pull request #1085 from Zokrates/eq-expression
Introduce EqExpression to encapsulate treatment of equality expressions
This commit is contained in:
commit
780668016a
7 changed files with 389 additions and 229 deletions
|
@ -2740,21 +2740,21 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
|||
|
||||
match (e1_checked, e2_checked) {
|
||||
(TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => {
|
||||
Ok(BooleanExpression::FieldEq(box e1, box e2).into())
|
||||
Ok(BooleanExpression::FieldEq(EqExpression::new(e1, e2)).into())
|
||||
}
|
||||
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
|
||||
Ok(BooleanExpression::BoolEq(box e1, box e2).into())
|
||||
Ok(BooleanExpression::BoolEq(EqExpression::new(e1, e2)).into())
|
||||
}
|
||||
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
|
||||
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
|
||||
Ok(BooleanExpression::ArrayEq(EqExpression::new(e1, e2)).into())
|
||||
}
|
||||
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
|
||||
Ok(BooleanExpression::StructEq(box e1, box e2).into())
|
||||
Ok(BooleanExpression::StructEq(EqExpression::new(e1, e2)).into())
|
||||
}
|
||||
(TypedExpression::Uint(e1), TypedExpression::Uint(e2))
|
||||
if e1.get_type() == e2.get_type() =>
|
||||
{
|
||||
Ok(BooleanExpression::UintEq(box e1, box e2).into())
|
||||
Ok(BooleanExpression::UintEq(EqExpression::new(e1, e2)).into())
|
||||
}
|
||||
(e1, e2) => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
|
|
@ -305,6 +305,14 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
fold_select_expression(self, statements_buffer, select)
|
||||
}
|
||||
|
||||
fn fold_eq_expression<E: Flatten<'ast, T>>(
|
||||
&mut self,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
eq: typed_absy::EqExpression<E>,
|
||||
) -> zir::BooleanExpression<'ast, T> {
|
||||
fold_eq_expression(self, statements_buffer, eq)
|
||||
}
|
||||
|
||||
fn fold_field_expression(
|
||||
&mut self,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
|
@ -802,6 +810,16 @@ fn conjunction_tree<'ast, T: Field>(
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_eq_expression<'ast, T: Field, E: Flatten<'ast, T>>(
|
||||
f: &mut Flattener<T>,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
e: typed_absy::EqExpression<E>,
|
||||
) -> zir::BooleanExpression<'ast, T> {
|
||||
let left = e.left.flatten(f, statements_buffer);
|
||||
let right = e.right.flatten(f, statements_buffer);
|
||||
conjunction_tree(&left, &right)
|
||||
}
|
||||
|
||||
fn fold_boolean_expression<'ast, T: Field>(
|
||||
f: &mut Flattener<T>,
|
||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||
|
@ -822,38 +840,11 @@ fn fold_boolean_expression<'ast, T: Field>(
|
|||
.unwrap()
|
||||
.id,
|
||||
),
|
||||
typed_absy::BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::FieldEq(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::BoolEq(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_boolean_expression(statements_buffer, e2);
|
||||
zir::BooleanExpression::BoolEq(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::ArrayEq(box e1, box e2) => {
|
||||
let e1 = f.fold_array_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_array_expression(statements_buffer, e2);
|
||||
|
||||
assert_eq!(e1.len(), e2.len());
|
||||
|
||||
conjunction_tree(&e1, &e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = f.fold_struct_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_struct_expression(statements_buffer, e2);
|
||||
|
||||
assert_eq!(e1.len(), e2.len());
|
||||
|
||||
conjunction_tree(&e1, &e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_uint_expression(statements_buffer, e2);
|
||||
|
||||
zir::BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
typed_absy::BooleanExpression::FieldEq(e) => f.fold_eq_expression(statements_buffer, e),
|
||||
typed_absy::BooleanExpression::BoolEq(e) => f.fold_eq_expression(statements_buffer, e),
|
||||
typed_absy::BooleanExpression::ArrayEq(e) => f.fold_eq_expression(statements_buffer, e),
|
||||
typed_absy::BooleanExpression::StructEq(e) => f.fold_eq_expression(statements_buffer, e),
|
||||
typed_absy::BooleanExpression::UintEq(e) => f.fold_eq_expression(statements_buffer, e),
|
||||
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(statements_buffer, e1);
|
||||
let e2 = f.fold_field_expression(statements_buffer, e2);
|
||||
|
|
|
@ -1133,6 +1133,45 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn fold_eq_expression<
|
||||
E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
e: EqExpression<E>,
|
||||
) -> Result<EqOrBoolean<'ast, T, E>, Self::Error> {
|
||||
let left = e.left.fold(self)?;
|
||||
let right = e.right.fold(self)?;
|
||||
|
||||
if let (Ok(t_left), Ok(t_right)) = (
|
||||
ConcreteType::try_from(left.get_type()),
|
||||
ConcreteType::try_from(right.get_type()),
|
||||
) {
|
||||
if t_left != t_right {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot compare {} of type {} to {} of type {}",
|
||||
left, t_left, right, t_right
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// if the two expressions are the same, we can reduce to `true`.
|
||||
// Note that if they are different we cannot reduce to `false`: `a == 1` may still be `true` even though `a` and `1` are different expressions
|
||||
if left == right {
|
||||
return Ok(EqOrBoolean::Boolean(BooleanExpression::Value(true)));
|
||||
}
|
||||
|
||||
// if both expressions are constant, we can reduce the equality check after we put them in canonical form
|
||||
if left.is_constant() && right.is_constant() {
|
||||
let left = left.into_canonical_constant();
|
||||
let right = right.into_canonical_constant();
|
||||
Ok(EqOrBoolean::Boolean(BooleanExpression::Value(
|
||||
left == right,
|
||||
)))
|
||||
} else {
|
||||
Ok(EqOrBoolean::Eq(EqExpression::new(left, right)))
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_boolean_expression(
|
||||
&mut self,
|
||||
e: BooleanExpression<'ast, T>,
|
||||
|
@ -1150,73 +1189,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
},
|
||||
None => Ok(BooleanExpression::Identifier(id)),
|
||||
},
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 == n2))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::FieldEq(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = self.fold_uint_expression(e1)?;
|
||||
let e2 = self.fold_uint_expression(e2)?;
|
||||
|
||||
match (e1.as_inner(), e2.as_inner()) {
|
||||
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 == n2))
|
||||
}
|
||||
_ => Ok(BooleanExpression::UintEq(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::BoolEq(box e1, box e2) => {
|
||||
let e1 = self.fold_boolean_expression(e1)?;
|
||||
let e2 = self.fold_boolean_expression(e2)?;
|
||||
|
||||
match (e1, e2) {
|
||||
(BooleanExpression::Value(n1), BooleanExpression::Value(n2)) => {
|
||||
Ok(BooleanExpression::Value(n1 == n2))
|
||||
}
|
||||
(e1, e2) => Ok(BooleanExpression::BoolEq(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
BooleanExpression::ArrayEq(box e1, box e2) => {
|
||||
let e1 = self.fold_array_expression(e1)?;
|
||||
let e2 = self.fold_array_expression(e2)?;
|
||||
|
||||
if let (Ok(t1), Ok(t2)) = (
|
||||
ConcreteType::try_from(e1.get_type()),
|
||||
ConcreteType::try_from(e2.get_type()),
|
||||
) {
|
||||
if t1 != t2 {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot compare {} of type {} to {} of type {}",
|
||||
e1, t1, e2, t2
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(BooleanExpression::ArrayEq(box e1, box e2))
|
||||
}
|
||||
BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = self.fold_struct_expression(e1)?;
|
||||
let e2 = self.fold_struct_expression(e2)?;
|
||||
|
||||
let t1 = e1.get_type();
|
||||
let t2 = e2.get_type();
|
||||
|
||||
if t1 != t2 {
|
||||
return Err(Error::Type(format!(
|
||||
"Cannot compare {} of type {} to {} of type {}",
|
||||
e1, t1, e2, t2
|
||||
)));
|
||||
};
|
||||
|
||||
Ok(BooleanExpression::StructEq(box e1, box e2))
|
||||
}
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = self.fold_field_expression(e1)?;
|
||||
let e2 = self.fold_field_expression(e2)?;
|
||||
|
@ -1522,63 +1494,219 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn field_eq() {
|
||||
let e_true = BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
);
|
||||
let e_constant_true = BooleanExpression::FieldEq(EqExpression::new(
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
));
|
||||
|
||||
let e_false = BooleanExpression::FieldEq(
|
||||
box FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
);
|
||||
let e_constant_false = BooleanExpression::FieldEq(EqExpression::new(
|
||||
FieldElementExpression::Number(Bn128Field::from(4)),
|
||||
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||
));
|
||||
|
||||
let e_identifier_true: BooleanExpression<Bn128Field> =
|
||||
BooleanExpression::FieldEq(EqExpression::new(
|
||||
FieldElementExpression::Identifier("a".into()),
|
||||
FieldElementExpression::Identifier("a".into()),
|
||||
));
|
||||
|
||||
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
|
||||
BooleanExpression::FieldEq(EqExpression::new(
|
||||
FieldElementExpression::Identifier("a".into()),
|
||||
FieldElementExpression::Identifier("b".into()),
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_true),
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_false),
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bool_eq() {
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(false)
|
||||
)),
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(true)
|
||||
)),
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(
|
||||
box BooleanExpression::Value(true),
|
||||
box BooleanExpression::Value(false)
|
||||
)),
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(true),
|
||||
BooleanExpression::Value(false)
|
||||
))),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(
|
||||
box BooleanExpression::Value(false),
|
||||
box BooleanExpression::Value(true)
|
||||
)),
|
||||
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
|
||||
BooleanExpression::Value(false),
|
||||
BooleanExpression::Value(true)
|
||||
))),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_eq() {
|
||||
let e_constant_true = BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
let e_constant_false = BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(4usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
let e_identifier_true: BooleanExpression<Bn128Field> =
|
||||
BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
|
||||
BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Identifier("a".into())
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Identifier("b".into())
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
let e_non_canonical_true = BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Spread(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)
|
||||
.into()]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
let e_non_canonical_false = BooleanExpression::ArrayEq(EqExpression::new(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Spread(
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32)
|
||||
.into(),
|
||||
)
|
||||
.into()]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
ArrayExpressionInner::Value(
|
||||
vec![TypedExpressionOrSpread::Expression(
|
||||
FieldElementExpression::Number(Bn128Field::from(4usize)).into(),
|
||||
)]
|
||||
.into(),
|
||||
)
|
||||
.annotate(Type::FieldElement, 1u32),
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_constant_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_identifier_unchanged.clone()),
|
||||
Ok(e_identifier_unchanged)
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_true),
|
||||
Ok(BooleanExpression::Value(true))
|
||||
);
|
||||
assert_eq!(
|
||||
Propagator::with_constants(&mut Constants::new())
|
||||
.fold_boolean_expression(e_non_canonical_false),
|
||||
Ok(BooleanExpression::Value(false))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -56,10 +56,10 @@ impl<'ast> VariableWriteRemover {
|
|||
.map(|i| match inner_ty {
|
||||
Type::Int => unreachable!(),
|
||||
Type::Array(..) => ArrayExpression::conditional(
|
||||
BooleanExpression::UintEq(
|
||||
box i.into(),
|
||||
box head.clone(),
|
||||
),
|
||||
BooleanExpression::UintEq(EqExpression::new(
|
||||
i.into(),
|
||||
head.clone(),
|
||||
)),
|
||||
match Self::choose_many(
|
||||
ArrayExpression::select(base.clone(), i).into(),
|
||||
tail.clone(),
|
||||
|
@ -77,10 +77,10 @@ impl<'ast> VariableWriteRemover {
|
|||
)
|
||||
.into(),
|
||||
Type::Struct(..) => StructExpression::conditional(
|
||||
BooleanExpression::UintEq(
|
||||
box i.into(),
|
||||
box head.clone(),
|
||||
),
|
||||
BooleanExpression::UintEq(EqExpression::new(
|
||||
i.into(),
|
||||
head.clone(),
|
||||
)),
|
||||
match Self::choose_many(
|
||||
StructExpression::select(base.clone(), i).into(),
|
||||
tail.clone(),
|
||||
|
@ -98,10 +98,10 @@ impl<'ast> VariableWriteRemover {
|
|||
)
|
||||
.into(),
|
||||
Type::FieldElement => FieldElementExpression::conditional(
|
||||
BooleanExpression::UintEq(
|
||||
box i.into(),
|
||||
box head.clone(),
|
||||
),
|
||||
BooleanExpression::UintEq(EqExpression::new(
|
||||
i.into(),
|
||||
head.clone(),
|
||||
)),
|
||||
match Self::choose_many(
|
||||
FieldElementExpression::select(base.clone(), i)
|
||||
.into(),
|
||||
|
@ -120,10 +120,10 @@ impl<'ast> VariableWriteRemover {
|
|||
)
|
||||
.into(),
|
||||
Type::Boolean => BooleanExpression::conditional(
|
||||
BooleanExpression::UintEq(
|
||||
box i.into(),
|
||||
box head.clone(),
|
||||
),
|
||||
BooleanExpression::UintEq(EqExpression::new(
|
||||
i.into(),
|
||||
head.clone(),
|
||||
)),
|
||||
match Self::choose_many(
|
||||
BooleanExpression::select(base.clone(), i).into(),
|
||||
tail.clone(),
|
||||
|
@ -141,10 +141,10 @@ impl<'ast> VariableWriteRemover {
|
|||
)
|
||||
.into(),
|
||||
Type::Uint(..) => UExpression::conditional(
|
||||
BooleanExpression::UintEq(
|
||||
box i.into(),
|
||||
box head.clone(),
|
||||
),
|
||||
BooleanExpression::UintEq(EqExpression::new(
|
||||
i.into(),
|
||||
head.clone(),
|
||||
)),
|
||||
match Self::choose_many(
|
||||
UExpression::select(base.clone(), i).into(),
|
||||
tail.clone(),
|
||||
|
|
|
@ -315,6 +315,13 @@ pub trait Folder<'ast, T: Field>: Sized {
|
|||
fold_member_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_eq_expression<E: Expr<'ast, T> + PartialEq + Constant + Fold<'ast, T>>(
|
||||
&mut self,
|
||||
e: EqExpression<E>,
|
||||
) -> EqOrBoolean<'ast, T, E> {
|
||||
fold_eq_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_function_call_expression<
|
||||
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
|
||||
>(
|
||||
|
@ -684,6 +691,13 @@ pub fn fold_member_expression<
|
|||
))
|
||||
}
|
||||
|
||||
pub fn fold_eq_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: EqExpression<E>,
|
||||
) -> EqOrBoolean<'ast, T, E> {
|
||||
EqOrBoolean::Eq(EqExpression::new(e.left.fold(f), e.right.fold(f)))
|
||||
}
|
||||
|
||||
pub fn fold_select_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
|
@ -715,31 +729,26 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
BooleanExpression::Block(block) => BooleanExpression::Block(f.fold_block_expression(block)),
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
BooleanExpression::FieldEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::BoolEq(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1);
|
||||
let e2 = f.fold_boolean_expression(e2);
|
||||
BooleanExpression::BoolEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::ArrayEq(box e1, box e2) => {
|
||||
let e1 = f.fold_array_expression(e1);
|
||||
let e2 = f.fold_array_expression(e2);
|
||||
BooleanExpression::ArrayEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = f.fold_struct_expression(e1);
|
||||
let e2 = f.fold_struct_expression(e2);
|
||||
BooleanExpression::StructEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1);
|
||||
let e2 = f.fold_uint_expression(e2);
|
||||
BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldEq(e) => match f.fold_eq_expression(e) {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::BoolEq(e) => match f.fold_eq_expression(e) {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::ArrayEq(e) => match f.fold_eq_expression(e) {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::StructEq(e) => match f.fold_eq_expression(e) {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::UintEq(e) => match f.fold_eq_expression(e) {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1);
|
||||
let e2 = f.fold_field_expression(e2);
|
||||
|
|
|
@ -878,6 +878,28 @@ impl<'ast, T> TypedExpressionListInner<'ast, T> {
|
|||
TypedExpressionList { inner: self, types }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct EqExpression<E> {
|
||||
pub left: Box<E>,
|
||||
pub right: Box<E>,
|
||||
}
|
||||
|
||||
impl<E> EqExpression<E> {
|
||||
pub fn new(left: E, right: E) -> Self {
|
||||
EqExpression {
|
||||
left: box left,
|
||||
right: box right,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: fmt::Display> fmt::Display for EqExpression<E> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{} == {}", self.left, self.right)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
|
||||
pub struct BlockExpression<'ast, T, E> {
|
||||
pub statements: Vec<TypedStatement<'ast, T>>,
|
||||
|
@ -1141,20 +1163,11 @@ pub enum BooleanExpression<'ast, T> {
|
|||
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
FieldEq(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
BoolEq(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
),
|
||||
ArrayEq(Box<ArrayExpression<'ast, T>>, Box<ArrayExpression<'ast, T>>),
|
||||
StructEq(
|
||||
Box<StructExpression<'ast, T>>,
|
||||
Box<StructExpression<'ast, T>>,
|
||||
),
|
||||
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||
FieldEq(EqExpression<FieldElementExpression<'ast, T>>),
|
||||
BoolEq(EqExpression<BooleanExpression<'ast, T>>),
|
||||
ArrayEq(EqExpression<ArrayExpression<'ast, T>>),
|
||||
StructEq(EqExpression<StructExpression<'ast, T>>),
|
||||
UintEq(EqExpression<UExpression<'ast, T>>),
|
||||
Or(
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
Box<BooleanExpression<'ast, T>>,
|
||||
|
@ -1205,7 +1218,7 @@ impl<'ast, T> IntoIterator for ArrayValue<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> ArrayValue<'ast, T> {
|
||||
impl<'ast, T: Field> ArrayValue<'ast, T> {
|
||||
fn expression_at_aux<
|
||||
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
|
||||
>(
|
||||
|
@ -1529,11 +1542,11 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
|||
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
|
||||
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
|
||||
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::StructEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
|
||||
BooleanExpression::FieldEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::BoolEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::ArrayEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::StructEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::UintEq(ref e) => write!(f, "{}", e),
|
||||
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
|
||||
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
|
||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||
|
@ -1625,7 +1638,7 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
|
|||
|
||||
// Common behaviour across expressions
|
||||
|
||||
pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
|
||||
pub trait Expr<'ast, T>: fmt::Display + From<TypedExpression<'ast, T>> {
|
||||
type Inner;
|
||||
type Ty: Clone + IntoTypes<'ast, T>;
|
||||
|
||||
|
@ -1638,7 +1651,7 @@ pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
|
|||
fn as_inner_mut(&mut self) -> &mut Self::Inner;
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1659,7 +1672,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1680,7 +1693,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
|
||||
type Inner = UExpressionInner<'ast, T>;
|
||||
type Ty = UBitwidth;
|
||||
|
||||
|
@ -1701,7 +1714,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> {
|
||||
type Inner = StructExpressionInner<'ast, T>;
|
||||
type Ty = StructType<'ast, T>;
|
||||
|
||||
|
@ -1722,7 +1735,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
||||
type Inner = ArrayExpressionInner<'ast, T>;
|
||||
type Ty = ArrayType<'ast, T>;
|
||||
|
||||
|
@ -1743,7 +1756,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> {
|
||||
type Inner = Self;
|
||||
type Ty = Type<'ast, T>;
|
||||
|
||||
|
@ -1764,7 +1777,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
impl<'ast, T: Field> Expr<'ast, T> for TypedExpressionList<'ast, T> {
|
||||
type Inner = TypedExpressionListInner<'ast, T>;
|
||||
type Ty = Types<'ast, T>;
|
||||
|
||||
|
@ -1796,6 +1809,11 @@ pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
|
|||
Expression(E::Inner),
|
||||
}
|
||||
|
||||
pub enum EqOrBoolean<'ast, T, E> {
|
||||
Eq(EqExpression<E>),
|
||||
Boolean(BooleanExpression<'ast, T>),
|
||||
}
|
||||
|
||||
pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||
Member(MemberExpression<'ast, T, E>),
|
||||
Expression(E::Inner),
|
||||
|
@ -1941,7 +1959,7 @@ impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Clone> Select<'ast, T> for TypedExpression<'ast, T> {
|
||||
impl<'ast, T: Field> Select<'ast, T> for TypedExpression<'ast, T> {
|
||||
fn select<I: Into<UExpression<'ast, T>>>(array: ArrayExpression<'ast, T>, index: I) -> Self {
|
||||
match *array.ty().ty {
|
||||
Type::Array(..) => ArrayExpression::select(array, index).into(),
|
||||
|
|
|
@ -218,6 +218,15 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
|||
fold_member_expression(self, ty, e)
|
||||
}
|
||||
|
||||
fn fold_eq_expression<
|
||||
E: Expr<'ast, T> + Typed<'ast, T> + PartialEq + Constant + ResultFold<'ast, T>,
|
||||
>(
|
||||
&mut self,
|
||||
e: EqExpression<E>,
|
||||
) -> Result<EqOrBoolean<'ast, T, E>, Self::Error> {
|
||||
fold_eq_expression(self, e)
|
||||
}
|
||||
|
||||
fn fold_select_expression<
|
||||
E: Expr<'ast, T>
|
||||
+ Select<'ast, T>
|
||||
|
@ -736,6 +745,16 @@ pub fn fold_member_expression<
|
|||
)))
|
||||
}
|
||||
|
||||
pub fn fold_eq_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>(
|
||||
f: &mut F,
|
||||
e: EqExpression<E>,
|
||||
) -> Result<EqOrBoolean<'ast, T, E>, F::Error> {
|
||||
Ok(EqOrBoolean::Eq(EqExpression::new(
|
||||
e.left.fold(f)?,
|
||||
e.right.fold(f)?,
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn fold_select_expression<
|
||||
'ast,
|
||||
T: Field,
|
||||
|
@ -788,31 +807,26 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
}
|
||||
BooleanExpression::Value(v) => BooleanExpression::Value(v),
|
||||
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
|
||||
BooleanExpression::FieldEq(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
BooleanExpression::FieldEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::BoolEq(box e1, box e2) => {
|
||||
let e1 = f.fold_boolean_expression(e1)?;
|
||||
let e2 = f.fold_boolean_expression(e2)?;
|
||||
BooleanExpression::BoolEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::ArrayEq(box e1, box e2) => {
|
||||
let e1 = f.fold_array_expression(e1)?;
|
||||
let e2 = f.fold_array_expression(e2)?;
|
||||
BooleanExpression::ArrayEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::StructEq(box e1, box e2) => {
|
||||
let e1 = f.fold_struct_expression(e1)?;
|
||||
let e2 = f.fold_struct_expression(e2)?;
|
||||
BooleanExpression::StructEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::UintEq(box e1, box e2) => {
|
||||
let e1 = f.fold_uint_expression(e1)?;
|
||||
let e2 = f.fold_uint_expression(e2)?;
|
||||
BooleanExpression::UintEq(box e1, box e2)
|
||||
}
|
||||
BooleanExpression::FieldEq(e) => match f.fold_eq_expression(e)? {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::BoolEq(e) => match f.fold_eq_expression(e)? {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::ArrayEq(e) => match f.fold_eq_expression(e)? {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::StructEq(e) => match f.fold_eq_expression(e)? {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::UintEq(e) => match f.fold_eq_expression(e)? {
|
||||
EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e),
|
||||
EqOrBoolean::Boolean(u) => u,
|
||||
},
|
||||
BooleanExpression::FieldLt(box e1, box e2) => {
|
||||
let e1 = f.fold_field_expression(e1)?;
|
||||
let e2 = f.fold_field_expression(e2)?;
|
||||
|
|
Loading…
Reference in a new issue