1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #1085 from Zokrates/eq-expression

Introduce EqExpression to encapsulate treatment of equality expressions
This commit is contained in:
Thibaut Schaeffer 2022-01-12 21:46:48 +01:00 committed by GitHub
commit 780668016a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 389 additions and 229 deletions

View file

@ -2740,21 +2740,21 @@ impl<'ast, T: Field> Checker<'ast, T> {
match (e1_checked, e2_checked) {
(TypedExpression::FieldElement(e1), TypedExpression::FieldElement(e2)) => {
Ok(BooleanExpression::FieldEq(box e1, box e2).into())
Ok(BooleanExpression::FieldEq(EqExpression::new(e1, e2)).into())
}
(TypedExpression::Boolean(e1), TypedExpression::Boolean(e2)) => {
Ok(BooleanExpression::BoolEq(box e1, box e2).into())
Ok(BooleanExpression::BoolEq(EqExpression::new(e1, e2)).into())
}
(TypedExpression::Array(e1), TypedExpression::Array(e2)) => {
Ok(BooleanExpression::ArrayEq(box e1, box e2).into())
Ok(BooleanExpression::ArrayEq(EqExpression::new(e1, e2)).into())
}
(TypedExpression::Struct(e1), TypedExpression::Struct(e2)) => {
Ok(BooleanExpression::StructEq(box e1, box e2).into())
Ok(BooleanExpression::StructEq(EqExpression::new(e1, e2)).into())
}
(TypedExpression::Uint(e1), TypedExpression::Uint(e2))
if e1.get_type() == e2.get_type() =>
{
Ok(BooleanExpression::UintEq(box e1, box e2).into())
Ok(BooleanExpression::UintEq(EqExpression::new(e1, e2)).into())
}
(e1, e2) => Err(ErrorInner {
pos: Some(pos),

View file

@ -305,6 +305,14 @@ impl<'ast, T: Field> Flattener<T> {
fold_select_expression(self, statements_buffer, select)
}
fn fold_eq_expression<E: Flatten<'ast, T>>(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
eq: typed_absy::EqExpression<E>,
) -> zir::BooleanExpression<'ast, T> {
fold_eq_expression(self, statements_buffer, eq)
}
fn fold_field_expression(
&mut self,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
@ -802,6 +810,16 @@ fn conjunction_tree<'ast, T: Field>(
}
}
fn fold_eq_expression<'ast, T: Field, E: Flatten<'ast, T>>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
e: typed_absy::EqExpression<E>,
) -> zir::BooleanExpression<'ast, T> {
let left = e.left.flatten(f, statements_buffer);
let right = e.right.flatten(f, statements_buffer);
conjunction_tree(&left, &right)
}
fn fold_boolean_expression<'ast, T: Field>(
f: &mut Flattener<T>,
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
@ -822,38 +840,11 @@ fn fold_boolean_expression<'ast, T: Field>(
.unwrap()
.id,
),
typed_absy::BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);
zir::BooleanExpression::FieldEq(box e1, box e2)
}
typed_absy::BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = f.fold_boolean_expression(statements_buffer, e1);
let e2 = f.fold_boolean_expression(statements_buffer, e2);
zir::BooleanExpression::BoolEq(box e1, box e2)
}
typed_absy::BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(statements_buffer, e1);
let e2 = f.fold_array_expression(statements_buffer, e2);
assert_eq!(e1.len(), e2.len());
conjunction_tree(&e1, &e2)
}
typed_absy::BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(statements_buffer, e1);
let e2 = f.fold_struct_expression(statements_buffer, e2);
assert_eq!(e1.len(), e2.len());
conjunction_tree(&e1, &e2)
}
typed_absy::BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(statements_buffer, e1);
let e2 = f.fold_uint_expression(statements_buffer, e2);
zir::BooleanExpression::UintEq(box e1, box e2)
}
typed_absy::BooleanExpression::FieldEq(e) => f.fold_eq_expression(statements_buffer, e),
typed_absy::BooleanExpression::BoolEq(e) => f.fold_eq_expression(statements_buffer, e),
typed_absy::BooleanExpression::ArrayEq(e) => f.fold_eq_expression(statements_buffer, e),
typed_absy::BooleanExpression::StructEq(e) => f.fold_eq_expression(statements_buffer, e),
typed_absy::BooleanExpression::UintEq(e) => f.fold_eq_expression(statements_buffer, e),
typed_absy::BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(statements_buffer, e1);
let e2 = f.fold_field_expression(statements_buffer, e2);

View file

@ -1133,6 +1133,45 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
}
fn fold_eq_expression<
E: Expr<'ast, T> + PartialEq + Constant + Typed<'ast, T> + ResultFold<'ast, T>,
>(
&mut self,
e: EqExpression<E>,
) -> Result<EqOrBoolean<'ast, T, E>, Self::Error> {
let left = e.left.fold(self)?;
let right = e.right.fold(self)?;
if let (Ok(t_left), Ok(t_right)) = (
ConcreteType::try_from(left.get_type()),
ConcreteType::try_from(right.get_type()),
) {
if t_left != t_right {
return Err(Error::Type(format!(
"Cannot compare {} of type {} to {} of type {}",
left, t_left, right, t_right
)));
}
};
// if the two expressions are the same, we can reduce to `true`.
// Note that if they are different we cannot reduce to `false`: `a == 1` may still be `true` even though `a` and `1` are different expressions
if left == right {
return Ok(EqOrBoolean::Boolean(BooleanExpression::Value(true)));
}
// if both expressions are constant, we can reduce the equality check after we put them in canonical form
if left.is_constant() && right.is_constant() {
let left = left.into_canonical_constant();
let right = right.into_canonical_constant();
Ok(EqOrBoolean::Boolean(BooleanExpression::Value(
left == right,
)))
} else {
Ok(EqOrBoolean::Eq(EqExpression::new(left, right)))
}
}
fn fold_boolean_expression(
&mut self,
e: BooleanExpression<'ast, T>,
@ -1150,73 +1189,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
},
None => Ok(BooleanExpression::Identifier(id)),
},
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = self.fold_field_expression(e1)?;
let e2 = self.fold_field_expression(e2)?;
match (e1, e2) {
(FieldElementExpression::Number(n1), FieldElementExpression::Number(n2)) => {
Ok(BooleanExpression::Value(n1 == n2))
}
(e1, e2) => Ok(BooleanExpression::FieldEq(box e1, box e2)),
}
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = self.fold_uint_expression(e1)?;
let e2 = self.fold_uint_expression(e2)?;
match (e1.as_inner(), e2.as_inner()) {
(UExpressionInner::Value(n1), UExpressionInner::Value(n2)) => {
Ok(BooleanExpression::Value(n1 == n2))
}
_ => Ok(BooleanExpression::UintEq(box e1, box e2)),
}
}
BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = self.fold_boolean_expression(e1)?;
let e2 = self.fold_boolean_expression(e2)?;
match (e1, e2) {
(BooleanExpression::Value(n1), BooleanExpression::Value(n2)) => {
Ok(BooleanExpression::Value(n1 == n2))
}
(e1, e2) => Ok(BooleanExpression::BoolEq(box e1, box e2)),
}
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = self.fold_array_expression(e1)?;
let e2 = self.fold_array_expression(e2)?;
if let (Ok(t1), Ok(t2)) = (
ConcreteType::try_from(e1.get_type()),
ConcreteType::try_from(e2.get_type()),
) {
if t1 != t2 {
return Err(Error::Type(format!(
"Cannot compare {} of type {} to {} of type {}",
e1, t1, e2, t2
)));
}
};
Ok(BooleanExpression::ArrayEq(box e1, box e2))
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = self.fold_struct_expression(e1)?;
let e2 = self.fold_struct_expression(e2)?;
let t1 = e1.get_type();
let t2 = e2.get_type();
if t1 != t2 {
return Err(Error::Type(format!(
"Cannot compare {} of type {} to {} of type {}",
e1, t1, e2, t2
)));
};
Ok(BooleanExpression::StructEq(box e1, box e2))
}
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = self.fold_field_expression(e1)?;
let e2 = self.fold_field_expression(e2)?;
@ -1522,63 +1494,219 @@ mod tests {
#[test]
fn field_eq() {
let e_true = BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(2)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
let e_constant_true = BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Number(Bn128Field::from(2)),
FieldElementExpression::Number(Bn128Field::from(2)),
));
let e_false = BooleanExpression::FieldEq(
box FieldElementExpression::Number(Bn128Field::from(4)),
box FieldElementExpression::Number(Bn128Field::from(2)),
);
let e_constant_false = BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Number(Bn128Field::from(4)),
FieldElementExpression::Number(Bn128Field::from(2)),
));
let e_identifier_true: BooleanExpression<Bn128Field> =
BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::Identifier("a".into()),
));
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
BooleanExpression::FieldEq(EqExpression::new(
FieldElementExpression::Identifier("a".into()),
FieldElementExpression::Identifier("b".into()),
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_true),
.fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_false),
.fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
}
#[test]
fn bool_eq() {
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(false),
box BooleanExpression::Value(false)
)),
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(false)
))),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(true),
box BooleanExpression::Value(true)
)),
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(true)
))),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(true),
box BooleanExpression::Value(false)
)),
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(true),
BooleanExpression::Value(false)
))),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::<Bn128Field>::with_constants(&mut Constants::new())
.fold_boolean_expression(BooleanExpression::BoolEq(
box BooleanExpression::Value(false),
box BooleanExpression::Value(true)
)),
.fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new(
BooleanExpression::Value(false),
BooleanExpression::Value(true)
))),
Ok(BooleanExpression::Value(false))
);
}
#[test]
fn array_eq() {
let e_constant_true = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
));
let e_constant_false = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(4usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
));
let e_identifier_true: BooleanExpression<Bn128Field> =
BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
));
let e_identifier_unchanged: BooleanExpression<Bn128Field> =
BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Identifier("a".into())
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Identifier("b".into())
.annotate(Type::FieldElement, 1u32),
));
let e_non_canonical_true = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Spread(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
)
.into()]
.into(),
)
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
));
let e_non_canonical_false = BooleanExpression::ArrayEq(EqExpression::new(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Spread(
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(2usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32)
.into(),
)
.into()]
.into(),
)
.annotate(Type::FieldElement, 1u32),
ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(
FieldElementExpression::Number(Bn128Field::from(4usize)).into(),
)]
.into(),
)
.annotate(Type::FieldElement, 1u32),
));
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_constant_false),
Ok(BooleanExpression::Value(false))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_identifier_unchanged.clone()),
Ok(e_identifier_unchanged)
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_true),
Ok(BooleanExpression::Value(true))
);
assert_eq!(
Propagator::with_constants(&mut Constants::new())
.fold_boolean_expression(e_non_canonical_false),
Ok(BooleanExpression::Value(false))
);
}

View file

@ -56,10 +56,10 @@ impl<'ast> VariableWriteRemover {
.map(|i| match inner_ty {
Type::Int => unreachable!(),
Type::Array(..) => ArrayExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
),
BooleanExpression::UintEq(EqExpression::new(
i.into(),
head.clone(),
)),
match Self::choose_many(
ArrayExpression::select(base.clone(), i).into(),
tail.clone(),
@ -77,10 +77,10 @@ impl<'ast> VariableWriteRemover {
)
.into(),
Type::Struct(..) => StructExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
),
BooleanExpression::UintEq(EqExpression::new(
i.into(),
head.clone(),
)),
match Self::choose_many(
StructExpression::select(base.clone(), i).into(),
tail.clone(),
@ -98,10 +98,10 @@ impl<'ast> VariableWriteRemover {
)
.into(),
Type::FieldElement => FieldElementExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
),
BooleanExpression::UintEq(EqExpression::new(
i.into(),
head.clone(),
)),
match Self::choose_many(
FieldElementExpression::select(base.clone(), i)
.into(),
@ -120,10 +120,10 @@ impl<'ast> VariableWriteRemover {
)
.into(),
Type::Boolean => BooleanExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
),
BooleanExpression::UintEq(EqExpression::new(
i.into(),
head.clone(),
)),
match Self::choose_many(
BooleanExpression::select(base.clone(), i).into(),
tail.clone(),
@ -141,10 +141,10 @@ impl<'ast> VariableWriteRemover {
)
.into(),
Type::Uint(..) => UExpression::conditional(
BooleanExpression::UintEq(
box i.into(),
box head.clone(),
),
BooleanExpression::UintEq(EqExpression::new(
i.into(),
head.clone(),
)),
match Self::choose_many(
UExpression::select(base.clone(), i).into(),
tail.clone(),

View file

@ -315,6 +315,13 @@ pub trait Folder<'ast, T: Field>: Sized {
fold_member_expression(self, ty, e)
}
fn fold_eq_expression<E: Expr<'ast, T> + PartialEq + Constant + Fold<'ast, T>>(
&mut self,
e: EqExpression<E>,
) -> EqOrBoolean<'ast, T, E> {
fold_eq_expression(self, e)
}
fn fold_function_call_expression<
E: Id<'ast, T> + From<TypedExpression<'ast, T>> + Expr<'ast, T> + FunctionCall<'ast, T>,
>(
@ -684,6 +691,13 @@ pub fn fold_member_expression<
))
}
pub fn fold_eq_expression<'ast, T: Field, E: Fold<'ast, T>, F: Folder<'ast, T>>(
f: &mut F,
e: EqExpression<E>,
) -> EqOrBoolean<'ast, T, E> {
EqOrBoolean::Eq(EqExpression::new(e.left.fold(f), e.right.fold(f)))
}
pub fn fold_select_expression<
'ast,
T: Field,
@ -715,31 +729,26 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
BooleanExpression::Block(block) => BooleanExpression::Block(f.fold_block_expression(block)),
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)),
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);
BooleanExpression::FieldEq(box e1, box e2)
}
BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1);
let e2 = f.fold_boolean_expression(e2);
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1);
let e2 = f.fold_array_expression(e2);
BooleanExpression::ArrayEq(box e1, box e2)
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1);
let e2 = f.fold_struct_expression(e2);
BooleanExpression::StructEq(box e1, box e2)
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1);
let e2 = f.fold_uint_expression(e2);
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::FieldEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::BoolEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::ArrayEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::StructEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::UintEq(e) => match f.fold_eq_expression(e) {
EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1);
let e2 = f.fold_field_expression(e2);

View file

@ -878,6 +878,28 @@ impl<'ast, T> TypedExpressionListInner<'ast, T> {
TypedExpressionList { inner: self, types }
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct EqExpression<E> {
pub left: Box<E>,
pub right: Box<E>,
}
impl<E> EqExpression<E> {
pub fn new(left: E, right: E) -> Self {
EqExpression {
left: box left,
right: box right,
}
}
}
impl<E: fmt::Display> fmt::Display for EqExpression<E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} == {}", self.left, self.right)
}
}
#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct BlockExpression<'ast, T, E> {
pub statements: Vec<TypedStatement<'ast, T>>,
@ -1141,20 +1163,11 @@ pub enum BooleanExpression<'ast, T> {
UintLe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGe(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
UintGt(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq(
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
BoolEq(
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
),
ArrayEq(Box<ArrayExpression<'ast, T>>, Box<ArrayExpression<'ast, T>>),
StructEq(
Box<StructExpression<'ast, T>>,
Box<StructExpression<'ast, T>>,
),
UintEq(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
FieldEq(EqExpression<FieldElementExpression<'ast, T>>),
BoolEq(EqExpression<BooleanExpression<'ast, T>>),
ArrayEq(EqExpression<ArrayExpression<'ast, T>>),
StructEq(EqExpression<StructExpression<'ast, T>>),
UintEq(EqExpression<UExpression<'ast, T>>),
Or(
Box<BooleanExpression<'ast, T>>,
Box<BooleanExpression<'ast, T>>,
@ -1205,7 +1218,7 @@ impl<'ast, T> IntoIterator for ArrayValue<'ast, T> {
}
}
impl<'ast, T: Clone> ArrayValue<'ast, T> {
impl<'ast, T: Field> ArrayValue<'ast, T> {
fn expression_at_aux<
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
>(
@ -1529,11 +1542,11 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
BooleanExpression::UintLe(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs),
BooleanExpression::UintGe(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs),
BooleanExpression::UintGt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs),
BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::ArrayEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::StructEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::UintEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs),
BooleanExpression::FieldEq(ref e) => write!(f, "{}", e),
BooleanExpression::BoolEq(ref e) => write!(f, "{}", e),
BooleanExpression::ArrayEq(ref e) => write!(f, "{}", e),
BooleanExpression::StructEq(ref e) => write!(f, "{}", e),
BooleanExpression::UintEq(ref e) => write!(f, "{}", e),
BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs),
BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs),
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
@ -1625,7 +1638,7 @@ impl<'ast, T: Field> From<Variable<'ast, T>> for TypedExpression<'ast, T> {
// Common behaviour across expressions
pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
pub trait Expr<'ast, T>: fmt::Display + From<TypedExpression<'ast, T>> {
type Inner;
type Ty: Clone + IntoTypes<'ast, T>;
@ -1638,7 +1651,7 @@ pub trait Expr<'ast, T>: From<TypedExpression<'ast, T>> {
fn as_inner_mut(&mut self) -> &mut Self::Inner;
}
impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for FieldElementExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1659,7 +1672,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for FieldElementExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for BooleanExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1680,7 +1693,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for BooleanExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for UExpression<'ast, T> {
type Inner = UExpressionInner<'ast, T>;
type Ty = UBitwidth;
@ -1701,7 +1714,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for UExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for StructExpression<'ast, T> {
type Inner = StructExpressionInner<'ast, T>;
type Ty = StructType<'ast, T>;
@ -1722,7 +1735,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for StructExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for ArrayExpression<'ast, T> {
type Inner = ArrayExpressionInner<'ast, T>;
type Ty = ArrayType<'ast, T>;
@ -1743,7 +1756,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for ArrayExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for IntExpression<'ast, T> {
type Inner = Self;
type Ty = Type<'ast, T>;
@ -1764,7 +1777,7 @@ impl<'ast, T: Clone> Expr<'ast, T> for IntExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Expr<'ast, T> for TypedExpressionList<'ast, T> {
impl<'ast, T: Field> Expr<'ast, T> for TypedExpressionList<'ast, T> {
type Inner = TypedExpressionListInner<'ast, T>;
type Ty = Types<'ast, T>;
@ -1796,6 +1809,11 @@ pub enum SelectOrExpression<'ast, T, E: Expr<'ast, T>> {
Expression(E::Inner),
}
pub enum EqOrBoolean<'ast, T, E> {
Eq(EqExpression<E>),
Boolean(BooleanExpression<'ast, T>),
}
pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> {
Member(MemberExpression<'ast, T, E>),
Expression(E::Inner),
@ -1941,7 +1959,7 @@ impl<'ast, T> Select<'ast, T> for BooleanExpression<'ast, T> {
}
}
impl<'ast, T: Clone> Select<'ast, T> for TypedExpression<'ast, T> {
impl<'ast, T: Field> Select<'ast, T> for TypedExpression<'ast, T> {
fn select<I: Into<UExpression<'ast, T>>>(array: ArrayExpression<'ast, T>, index: I) -> Self {
match *array.ty().ty {
Type::Array(..) => ArrayExpression::select(array, index).into(),

View file

@ -218,6 +218,15 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fold_member_expression(self, ty, e)
}
fn fold_eq_expression<
E: Expr<'ast, T> + Typed<'ast, T> + PartialEq + Constant + ResultFold<'ast, T>,
>(
&mut self,
e: EqExpression<E>,
) -> Result<EqOrBoolean<'ast, T, E>, Self::Error> {
fold_eq_expression(self, e)
}
fn fold_select_expression<
E: Expr<'ast, T>
+ Select<'ast, T>
@ -736,6 +745,16 @@ pub fn fold_member_expression<
)))
}
pub fn fold_eq_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFolder<'ast, T>>(
f: &mut F,
e: EqExpression<E>,
) -> Result<EqOrBoolean<'ast, T, E>, F::Error> {
Ok(EqOrBoolean::Eq(EqExpression::new(
e.left.fold(f)?,
e.right.fold(f)?,
)))
}
pub fn fold_select_expression<
'ast,
T: Field,
@ -788,31 +807,26 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
}
BooleanExpression::Value(v) => BooleanExpression::Value(v),
BooleanExpression::Identifier(id) => BooleanExpression::Identifier(f.fold_name(id)?),
BooleanExpression::FieldEq(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;
BooleanExpression::FieldEq(box e1, box e2)
}
BooleanExpression::BoolEq(box e1, box e2) => {
let e1 = f.fold_boolean_expression(e1)?;
let e2 = f.fold_boolean_expression(e2)?;
BooleanExpression::BoolEq(box e1, box e2)
}
BooleanExpression::ArrayEq(box e1, box e2) => {
let e1 = f.fold_array_expression(e1)?;
let e2 = f.fold_array_expression(e2)?;
BooleanExpression::ArrayEq(box e1, box e2)
}
BooleanExpression::StructEq(box e1, box e2) => {
let e1 = f.fold_struct_expression(e1)?;
let e2 = f.fold_struct_expression(e2)?;
BooleanExpression::StructEq(box e1, box e2)
}
BooleanExpression::UintEq(box e1, box e2) => {
let e1 = f.fold_uint_expression(e1)?;
let e2 = f.fold_uint_expression(e2)?;
BooleanExpression::UintEq(box e1, box e2)
}
BooleanExpression::FieldEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => BooleanExpression::FieldEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::BoolEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => BooleanExpression::BoolEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::ArrayEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => BooleanExpression::ArrayEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::StructEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => BooleanExpression::StructEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::UintEq(e) => match f.fold_eq_expression(e)? {
EqOrBoolean::Eq(e) => BooleanExpression::UintEq(e),
EqOrBoolean::Boolean(u) => u,
},
BooleanExpression::FieldLt(box e1, box e2) => {
let e1 = f.fold_field_expression(e1)?;
let e2 = f.fold_field_expression(e2)?;