Merge pull request #905 from Zokrates/if-else-expression
Introduce if-else expression
This commit is contained in:
commit
b372f9fd7b
13 changed files with 427 additions and 461 deletions
1
changelogs/unreleased/905-schaeff
Normal file
1
changelogs/unreleased/905-schaeff
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve propagation on if-else expressions when consequence and alternative are equal
|
|
@ -1,10 +0,0 @@
|
||||||
def throwing_bound<N>(u32 x) -> u32:
|
|
||||||
assert(x == N)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
// this should compile: the conditional, even though it can throw, has a constant compile-time value `1`
|
|
||||||
// the value of the blocks should be propagated out, so that `if x == 0 then 1 else 1 fi` can be determined to be `1`
|
|
||||||
def main(u32 x):
|
|
||||||
for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do
|
|
||||||
endfor
|
|
||||||
return
|
|
|
@ -1988,26 +1988,22 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
TypedExpression::Boolean(condition) => {
|
TypedExpression::Boolean(condition) => {
|
||||||
match (consequence_checked, alternative_checked) {
|
match (consequence_checked, alternative_checked) {
|
||||||
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
|
(TypedExpression::FieldElement(consequence), TypedExpression::FieldElement(alternative)) => {
|
||||||
Ok(FieldElementExpression::IfElse(box condition, box consequence, box alternative).into())
|
Ok(FieldElementExpression::if_else(condition, consequence, alternative).into())
|
||||||
},
|
},
|
||||||
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
|
(TypedExpression::Boolean(consequence), TypedExpression::Boolean(alternative)) => {
|
||||||
Ok(BooleanExpression::IfElse(box condition, box consequence, box alternative).into())
|
Ok(BooleanExpression::if_else(condition, consequence, alternative).into())
|
||||||
},
|
},
|
||||||
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
|
(TypedExpression::Array(consequence), TypedExpression::Array(alternative)) => {
|
||||||
let inner_type = consequence.inner_type().clone();
|
Ok(ArrayExpression::if_else(condition, consequence, alternative).into())
|
||||||
let size = consequence.size();
|
|
||||||
Ok(ArrayExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(inner_type, size).into())
|
|
||||||
},
|
},
|
||||||
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
|
(TypedExpression::Struct(consequence), TypedExpression::Struct(alternative)) => {
|
||||||
let ty = consequence.ty().clone();
|
Ok(StructExpression::if_else(condition, consequence, alternative).into())
|
||||||
Ok(StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty).into())
|
|
||||||
},
|
},
|
||||||
(TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => {
|
(TypedExpression::Uint(consequence), TypedExpression::Uint(alternative)) => {
|
||||||
let bitwidth = consequence.bitwidth();
|
Ok(UExpression::if_else(condition, consequence, alternative).into())
|
||||||
Ok(UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth).into())
|
|
||||||
},
|
},
|
||||||
(TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => {
|
(TypedExpression::Int(consequence), TypedExpression::Int(alternative)) => {
|
||||||
Ok(IntExpression::IfElse(box condition, box consequence, box alternative).into())
|
Ok(IntExpression::if_else(condition, consequence, alternative).into())
|
||||||
},
|
},
|
||||||
(c, a) => Err(ErrorInner {
|
(c, a) => Err(ErrorInner {
|
||||||
pos: Some(pos),
|
pos: Some(pos),
|
||||||
|
|
|
@ -17,86 +17,17 @@ impl Isolator {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
|
impl<'ast, T: Field> Folder<'ast, T> for Isolator {
|
||||||
fn fold_field_expression(
|
fn fold_if_else_expression<
|
||||||
|
E: Expr<'ast, T> + Block<'ast, T> + Fold<'ast, T> + IfElse<'ast, T>,
|
||||||
|
>(
|
||||||
&mut self,
|
&mut self,
|
||||||
e: FieldElementExpression<'ast, T>,
|
_: &E::Ty,
|
||||||
) -> FieldElementExpression<'ast, T> {
|
e: IfElseExpression<'ast, T, E>,
|
||||||
match e {
|
) -> IfElseOrExpression<'ast, T, E> {
|
||||||
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
|
IfElseOrExpression::IfElse(IfElseExpression::new(
|
||||||
FieldElementExpression::IfElse(
|
self.fold_boolean_expression(*e.condition),
|
||||||
box self.fold_boolean_expression(condition),
|
E::block(vec![], e.consequence.fold(self)),
|
||||||
box FieldElementExpression::block(vec![], consequence.fold(self)),
|
E::block(vec![], e.alternative.fold(self)),
|
||||||
box FieldElementExpression::block(vec![], alternative.fold(self)),
|
))
|
||||||
)
|
|
||||||
}
|
|
||||||
e => fold_field_expression(self, e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fold_boolean_expression(
|
|
||||||
&mut self,
|
|
||||||
e: BooleanExpression<'ast, T>,
|
|
||||||
) -> BooleanExpression<'ast, T> {
|
|
||||||
match e {
|
|
||||||
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
BooleanExpression::IfElse(
|
|
||||||
box self.fold_boolean_expression(condition),
|
|
||||||
box BooleanExpression::block(vec![], consequence.fold(self)),
|
|
||||||
box BooleanExpression::block(vec![], alternative.fold(self)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
e => fold_boolean_expression(self, e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fold_uint_expression_inner(
|
|
||||||
&mut self,
|
|
||||||
bitwidth: UBitwidth,
|
|
||||||
e: UExpressionInner<'ast, T>,
|
|
||||||
) -> UExpressionInner<'ast, T> {
|
|
||||||
match e {
|
|
||||||
UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
UExpressionInner::IfElse(
|
|
||||||
box self.fold_boolean_expression(condition),
|
|
||||||
box UExpression::block(vec![], consequence.fold(self)),
|
|
||||||
box UExpression::block(vec![], alternative.fold(self)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
e => fold_uint_expression_inner(self, bitwidth, e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fold_array_expression_inner(
|
|
||||||
&mut self,
|
|
||||||
array_ty: &ArrayType<'ast, T>,
|
|
||||||
e: ArrayExpressionInner<'ast, T>,
|
|
||||||
) -> ArrayExpressionInner<'ast, T> {
|
|
||||||
match e {
|
|
||||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
ArrayExpressionInner::IfElse(
|
|
||||||
box self.fold_boolean_expression(condition),
|
|
||||||
box ArrayExpression::block(vec![], consequence.fold(self)),
|
|
||||||
box ArrayExpression::block(vec![], alternative.fold(self)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
e => fold_array_expression_inner(self, array_ty, e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fold_struct_expression_inner(
|
|
||||||
&mut self,
|
|
||||||
struct_ty: &StructType<'ast, T>,
|
|
||||||
e: StructExpressionInner<'ast, T>,
|
|
||||||
) -> StructExpressionInner<'ast, T> {
|
|
||||||
match e {
|
|
||||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
StructExpressionInner::IfElse(
|
|
||||||
box self.fold_boolean_expression(condition),
|
|
||||||
box StructExpression::block(vec![], consequence.fold(self)),
|
|
||||||
box StructExpression::block(vec![], alternative.fold(self)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
e => fold_struct_expression_inner(self, struct_ty, e),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,64 @@ fn flatten_identifier_rec<'ast>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait Flatten<'ast, T: Field> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::FieldElementExpression<'ast, T> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
vec![f.fold_field_expression(statements_buffer, self).into()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::BooleanExpression<'ast, T> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
vec![f.fold_boolean_expression(statements_buffer, self).into()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::UExpression<'ast, T> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
vec![f.fold_uint_expression(statements_buffer, self).into()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::ArrayExpression<'ast, T> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
f.fold_array_expression(statements_buffer, self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: Field> Flatten<'ast, T> for typed_absy::StructExpression<'ast, T> {
|
||||||
|
fn flatten(
|
||||||
|
self,
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
f.fold_struct_expression(statements_buffer, self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Flattener<T> {
|
impl<'ast, T: Field> Flattener<T> {
|
||||||
pub fn flatten(p: typed_absy::TypedProgram<T>) -> zir::ZirProgram<T> {
|
pub fn flatten(p: typed_absy::TypedProgram<T>) -> zir::ZirProgram<T> {
|
||||||
let mut f = Flattener::default();
|
let mut f = Flattener::default();
|
||||||
|
@ -223,7 +281,15 @@ impl<'ast, T: Field> Flattener<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_member_expression<E>(
|
fn fold_if_else_expression<E: Flatten<'ast, T>>(
|
||||||
|
&mut self,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
c: typed_absy::IfElseExpression<'ast, T, E>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
fold_if_else_expression(self, statements_buffer, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_member_expression<E>(
|
||||||
&mut self,
|
&mut self,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
m: typed_absy::MemberExpression<'ast, T, E>,
|
m: typed_absy::MemberExpression<'ast, T, E>,
|
||||||
|
@ -231,7 +297,7 @@ impl<'ast, T: Field> Flattener<T> {
|
||||||
fold_member_expression(self, statements_buffer, m)
|
fold_member_expression(self, statements_buffer, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_select_expression<E>(
|
fn fold_select_expression<E>(
|
||||||
&mut self,
|
&mut self,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
select: typed_absy::SelectExpression<'ast, T, E>,
|
select: typed_absy::SelectExpression<'ast, T, E>,
|
||||||
|
@ -289,7 +355,7 @@ impl<'ast, T: Field> Flattener<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_statement<'ast, T: Field>(
|
fn fold_statement<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
s: typed_absy::TypedStatement<'ast, T>,
|
s: typed_absy::TypedStatement<'ast, T>,
|
||||||
|
@ -334,7 +400,7 @@ pub fn fold_statement<'ast, T: Field>(
|
||||||
statements_buffer.extend(res);
|
statements_buffer.extend(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_array_expression_inner<'ast, T: Field>(
|
fn fold_array_expression_inner<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
ty: &typed_absy::types::ConcreteType,
|
ty: &typed_absy::types::ConcreteType,
|
||||||
|
@ -376,44 +442,8 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
||||||
exprs
|
exprs
|
||||||
}
|
}
|
||||||
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
|
typed_absy::ArrayExpressionInner::FunctionCall(..) => unreachable!(),
|
||||||
typed_absy::ArrayExpressionInner::IfElse(
|
typed_absy::ArrayExpressionInner::IfElse(c) => {
|
||||||
box condition,
|
f.fold_if_else_expression(statements_buffer, c)
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
) => {
|
|
||||||
let mut consequence_statements = vec![];
|
|
||||||
let mut alternative_statements = vec![];
|
|
||||||
|
|
||||||
let condition = f.fold_boolean_expression(statements_buffer, condition);
|
|
||||||
let consequence = f.fold_array_expression(&mut consequence_statements, consequence);
|
|
||||||
let alternative = f.fold_array_expression(&mut alternative_statements, alternative);
|
|
||||||
|
|
||||||
assert_eq!(consequence.len(), alternative.len());
|
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
|
||||||
condition.clone(),
|
|
||||||
consequence_statements,
|
|
||||||
alternative_statements,
|
|
||||||
));
|
|
||||||
|
|
||||||
use crate::zir::IfElse;
|
|
||||||
|
|
||||||
consequence
|
|
||||||
.into_iter()
|
|
||||||
.zip(alternative.into_iter())
|
|
||||||
.map(|(c, a)| match (c, a) {
|
|
||||||
(zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => {
|
|
||||||
zir::FieldElementExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
(zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => {
|
|
||||||
zir::BooleanExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
(zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => {
|
|
||||||
zir::UExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
typed_absy::ArrayExpressionInner::Member(m) => {
|
typed_absy::ArrayExpressionInner::Member(m) => {
|
||||||
f.fold_member_expression(statements_buffer, m)
|
f.fold_member_expression(statements_buffer, m)
|
||||||
|
@ -452,7 +482,7 @@ pub fn fold_array_expression_inner<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_struct_expression_inner<'ast, T: Field>(
|
fn fold_struct_expression_inner<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
ty: &typed_absy::types::ConcreteStructType,
|
ty: &typed_absy::types::ConcreteStructType,
|
||||||
|
@ -487,44 +517,8 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
||||||
.flat_map(|e| f.fold_expression(statements_buffer, e))
|
.flat_map(|e| f.fold_expression(statements_buffer, e))
|
||||||
.collect(),
|
.collect(),
|
||||||
typed_absy::StructExpressionInner::FunctionCall(..) => unreachable!(),
|
typed_absy::StructExpressionInner::FunctionCall(..) => unreachable!(),
|
||||||
typed_absy::StructExpressionInner::IfElse(
|
typed_absy::StructExpressionInner::IfElse(c) => {
|
||||||
box condition,
|
f.fold_if_else_expression(statements_buffer, c)
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
) => {
|
|
||||||
let mut consequence_statements = vec![];
|
|
||||||
let mut alternative_statements = vec![];
|
|
||||||
|
|
||||||
let condition = f.fold_boolean_expression(statements_buffer, condition);
|
|
||||||
let consequence = f.fold_struct_expression(&mut consequence_statements, consequence);
|
|
||||||
let alternative = f.fold_struct_expression(&mut alternative_statements, alternative);
|
|
||||||
|
|
||||||
assert_eq!(consequence.len(), alternative.len());
|
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
|
||||||
condition.clone(),
|
|
||||||
consequence_statements,
|
|
||||||
alternative_statements,
|
|
||||||
));
|
|
||||||
|
|
||||||
use zir::IfElse;
|
|
||||||
|
|
||||||
consequence
|
|
||||||
.into_iter()
|
|
||||||
.zip(alternative.into_iter())
|
|
||||||
.map(|(c, a)| match (c, a) {
|
|
||||||
(zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => {
|
|
||||||
zir::FieldElementExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
(zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => {
|
|
||||||
zir::BooleanExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
(zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => {
|
|
||||||
zir::UExpression::if_else(condition.clone(), c, a).into()
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
typed_absy::StructExpressionInner::Member(m) => {
|
typed_absy::StructExpressionInner::Member(m) => {
|
||||||
f.fold_member_expression(statements_buffer, m)
|
f.fold_member_expression(statements_buffer, m)
|
||||||
|
@ -535,7 +529,7 @@ pub fn fold_struct_expression_inner<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_member_expression<'ast, T: Field, E>(
|
fn fold_member_expression<'ast, T: Field, E>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
m: typed_absy::MemberExpression<'ast, T, E>,
|
m: typed_absy::MemberExpression<'ast, T, E>,
|
||||||
|
@ -571,7 +565,7 @@ pub fn fold_member_expression<'ast, T: Field, E>(
|
||||||
s[offset..offset + size].to_vec()
|
s[offset..offset + size].to_vec()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_select_expression<'ast, T: Field, E>(
|
fn fold_select_expression<'ast, T: Field, E>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
select: typed_absy::SelectExpression<'ast, T, E>,
|
select: typed_absy::SelectExpression<'ast, T, E>,
|
||||||
|
@ -593,7 +587,47 @@ pub fn fold_select_expression<'ast, T: Field, E>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_field_expression<'ast, T: Field>(
|
fn fold_if_else_expression<'ast, T: Field, E: Flatten<'ast, T>>(
|
||||||
|
f: &mut Flattener<T>,
|
||||||
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
|
c: typed_absy::IfElseExpression<'ast, T, E>,
|
||||||
|
) -> Vec<zir::ZirExpression<'ast, T>> {
|
||||||
|
let mut consequence_statements = vec![];
|
||||||
|
let mut alternative_statements = vec![];
|
||||||
|
|
||||||
|
let condition = f.fold_boolean_expression(statements_buffer, *c.condition);
|
||||||
|
let consequence = c.consequence.flatten(f, &mut consequence_statements);
|
||||||
|
let alternative = c.alternative.flatten(f, &mut alternative_statements);
|
||||||
|
|
||||||
|
assert_eq!(consequence.len(), alternative.len());
|
||||||
|
|
||||||
|
statements_buffer.push(zir::ZirStatement::IfElse(
|
||||||
|
condition.clone(),
|
||||||
|
consequence_statements,
|
||||||
|
alternative_statements,
|
||||||
|
));
|
||||||
|
|
||||||
|
use crate::zir::IfElse;
|
||||||
|
|
||||||
|
consequence
|
||||||
|
.into_iter()
|
||||||
|
.zip(alternative.into_iter())
|
||||||
|
.map(|(c, a)| match (c, a) {
|
||||||
|
(zir::ZirExpression::FieldElement(c), zir::ZirExpression::FieldElement(a)) => {
|
||||||
|
zir::FieldElementExpression::if_else(condition.clone(), c, a).into()
|
||||||
|
}
|
||||||
|
(zir::ZirExpression::Boolean(c), zir::ZirExpression::Boolean(a)) => {
|
||||||
|
zir::BooleanExpression::if_else(condition.clone(), c, a).into()
|
||||||
|
}
|
||||||
|
(zir::ZirExpression::Uint(c), zir::ZirExpression::Uint(a)) => {
|
||||||
|
zir::UExpression::if_else(condition.clone(), c, a).into()
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_field_expression<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
e: typed_absy::FieldElementExpression<'ast, T>,
|
e: typed_absy::FieldElementExpression<'ast, T>,
|
||||||
|
@ -647,26 +681,12 @@ pub fn fold_field_expression<'ast, T: Field>(
|
||||||
typed_absy::FieldElementExpression::Pos(box e) => {
|
typed_absy::FieldElementExpression::Pos(box e) => {
|
||||||
f.fold_field_expression(statements_buffer, e)
|
f.fold_field_expression(statements_buffer, e)
|
||||||
}
|
}
|
||||||
typed_absy::FieldElementExpression::IfElse(
|
typed_absy::FieldElementExpression::IfElse(c) => f
|
||||||
box condition,
|
.fold_if_else_expression(statements_buffer, c)
|
||||||
box consequence,
|
.pop()
|
||||||
box alternative,
|
.unwrap()
|
||||||
) => {
|
.try_into()
|
||||||
let mut consequence_statements = vec![];
|
.unwrap(),
|
||||||
let mut alternative_statements = vec![];
|
|
||||||
|
|
||||||
let condition = f.fold_boolean_expression(statements_buffer, condition);
|
|
||||||
let consequence = f.fold_field_expression(&mut consequence_statements, consequence);
|
|
||||||
let alternative = f.fold_field_expression(&mut alternative_statements, alternative);
|
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
|
||||||
condition.clone(),
|
|
||||||
consequence_statements,
|
|
||||||
alternative_statements,
|
|
||||||
));
|
|
||||||
|
|
||||||
zir::FieldElementExpression::IfElse(box condition, box consequence, box alternative)
|
|
||||||
}
|
|
||||||
typed_absy::FieldElementExpression::FunctionCall(..) => unreachable!(""),
|
typed_absy::FieldElementExpression::FunctionCall(..) => unreachable!(""),
|
||||||
typed_absy::FieldElementExpression::Select(select) => f
|
typed_absy::FieldElementExpression::Select(select) => f
|
||||||
.fold_select_expression(statements_buffer, select)
|
.fold_select_expression(statements_buffer, select)
|
||||||
|
@ -690,7 +710,7 @@ pub fn fold_field_expression<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_boolean_expression<'ast, T: Field>(
|
fn fold_boolean_expression<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
e: typed_absy::BooleanExpression<'ast, T>,
|
e: typed_absy::BooleanExpression<'ast, T>,
|
||||||
|
@ -836,22 +856,12 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
||||||
let e = f.fold_boolean_expression(statements_buffer, e);
|
let e = f.fold_boolean_expression(statements_buffer, e);
|
||||||
zir::BooleanExpression::Not(box e)
|
zir::BooleanExpression::Not(box e)
|
||||||
}
|
}
|
||||||
typed_absy::BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
|
typed_absy::BooleanExpression::IfElse(c) => f
|
||||||
let mut consequence_statements = vec![];
|
.fold_if_else_expression(statements_buffer, c)
|
||||||
let mut alternative_statements = vec![];
|
.pop()
|
||||||
|
.unwrap()
|
||||||
let condition = f.fold_boolean_expression(statements_buffer, condition);
|
.try_into()
|
||||||
let consequence = f.fold_boolean_expression(&mut consequence_statements, consequence);
|
.unwrap(),
|
||||||
let alternative = f.fold_boolean_expression(&mut alternative_statements, alternative);
|
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
|
||||||
condition.clone(),
|
|
||||||
consequence_statements,
|
|
||||||
alternative_statements,
|
|
||||||
));
|
|
||||||
|
|
||||||
zir::BooleanExpression::IfElse(box condition, box consequence, box alternative)
|
|
||||||
}
|
|
||||||
typed_absy::BooleanExpression::FunctionCall(..) => unreachable!(),
|
typed_absy::BooleanExpression::FunctionCall(..) => unreachable!(),
|
||||||
typed_absy::BooleanExpression::Select(select) => f
|
typed_absy::BooleanExpression::Select(select) => f
|
||||||
.fold_select_expression(statements_buffer, select)
|
.fold_select_expression(statements_buffer, select)
|
||||||
|
@ -868,7 +878,7 @@ pub fn fold_boolean_expression<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_uint_expression<'ast, T: Field>(
|
fn fold_uint_expression<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
e: typed_absy::UExpression<'ast, T>,
|
e: typed_absy::UExpression<'ast, T>,
|
||||||
|
@ -877,7 +887,7 @@ pub fn fold_uint_expression<'ast, T: Field>(
|
||||||
.annotate(e.bitwidth.to_usize())
|
.annotate(e.bitwidth.to_usize())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_uint_expression_inner<'ast, T: Field>(
|
fn fold_uint_expression_inner<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
bitwidth: UBitwidth,
|
bitwidth: UBitwidth,
|
||||||
|
@ -1008,26 +1018,17 @@ pub fn fold_uint_expression_inner<'ast, T: Field>(
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.into_inner(),
|
.into_inner(),
|
||||||
typed_absy::UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
typed_absy::UExpressionInner::IfElse(c) => zir::UExpression::try_from(
|
||||||
let mut consequence_statements = vec![];
|
f.fold_if_else_expression(statements_buffer, c)
|
||||||
let mut alternative_statements = vec![];
|
.pop()
|
||||||
|
.unwrap(),
|
||||||
let condition = f.fold_boolean_expression(statements_buffer, condition);
|
)
|
||||||
let consequence = f.fold_uint_expression(&mut consequence_statements, consequence);
|
.unwrap()
|
||||||
let alternative = f.fold_uint_expression(&mut alternative_statements, alternative);
|
.into_inner(),
|
||||||
|
|
||||||
statements_buffer.push(zir::ZirStatement::IfElse(
|
|
||||||
condition.clone(),
|
|
||||||
consequence_statements,
|
|
||||||
alternative_statements,
|
|
||||||
));
|
|
||||||
|
|
||||||
zir::UExpressionInner::IfElse(box condition, box consequence, box alternative)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_function<'ast, T: Field>(
|
fn fold_function<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
fun: typed_absy::TypedFunction<'ast, T>,
|
fun: typed_absy::TypedFunction<'ast, T>,
|
||||||
) -> zir::ZirFunction<'ast, T> {
|
) -> zir::ZirFunction<'ast, T> {
|
||||||
|
@ -1052,7 +1053,7 @@ pub fn fold_function<'ast, T: Field>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_array_expression<'ast, T: Field>(
|
fn fold_array_expression<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
e: typed_absy::ArrayExpression<'ast, T>,
|
e: typed_absy::ArrayExpression<'ast, T>,
|
||||||
|
@ -1069,7 +1070,7 @@ pub fn fold_array_expression<'ast, T: Field>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_struct_expression<'ast, T: Field>(
|
fn fold_struct_expression<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
statements_buffer: &mut Vec<zir::ZirStatement<'ast, T>>,
|
||||||
e: typed_absy::StructExpression<'ast, T>,
|
e: typed_absy::StructExpression<'ast, T>,
|
||||||
|
@ -1081,7 +1082,7 @@ pub fn fold_struct_expression<'ast, T: Field>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fold_program<'ast, T: Field>(
|
fn fold_program<'ast, T: Field>(
|
||||||
f: &mut Flattener<T>,
|
f: &mut Flattener<T>,
|
||||||
mut p: typed_absy::TypedProgram<'ast, T>,
|
mut p: typed_absy::TypedProgram<'ast, T>,
|
||||||
) -> zir::ZirProgram<'ast, T> {
|
) -> zir::ZirProgram<'ast, T> {
|
||||||
|
|
|
@ -290,6 +290,35 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
fold_function(self, f)
|
fold_function(self, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_if_else_expression<
|
||||||
|
E: Expr<'ast, T> + IfElse<'ast, T> + PartialEq + ResultFold<'ast, T>,
|
||||||
|
>(
|
||||||
|
&mut self,
|
||||||
|
_: &E::Ty,
|
||||||
|
e: IfElseExpression<'ast, T, E>,
|
||||||
|
) -> Result<IfElseOrExpression<'ast, T, E>, Self::Error> {
|
||||||
|
Ok(
|
||||||
|
match (
|
||||||
|
self.fold_boolean_expression(*e.condition)?,
|
||||||
|
e.consequence.fold(self)?,
|
||||||
|
e.alternative.fold(self)?,
|
||||||
|
) {
|
||||||
|
(BooleanExpression::Value(true), consequence, _) => {
|
||||||
|
IfElseOrExpression::Expression(consequence.into_inner())
|
||||||
|
}
|
||||||
|
(BooleanExpression::Value(false), _, alternative) => {
|
||||||
|
IfElseOrExpression::Expression(alternative.into_inner())
|
||||||
|
}
|
||||||
|
(_, consequence, alternative) if consequence == alternative => {
|
||||||
|
IfElseOrExpression::Expression(consequence.into_inner())
|
||||||
|
}
|
||||||
|
(condition, consequence, alternative) => IfElseOrExpression::IfElse(
|
||||||
|
IfElseExpression::new(condition, consequence, alternative),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_statement(
|
fn fold_statement(
|
||||||
&mut self,
|
&mut self,
|
||||||
s: TypedStatement<'ast, T>,
|
s: TypedStatement<'ast, T>,
|
||||||
|
@ -913,19 +942,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
box e2.annotate(bitwidth),
|
box e2.annotate(bitwidth),
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
UExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
let consequence = self.fold_uint_expression(consequence)?;
|
|
||||||
let alternative = self.fold_uint_expression(alternative)?;
|
|
||||||
match self.fold_boolean_expression(condition)? {
|
|
||||||
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
|
|
||||||
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
|
|
||||||
c => Ok(UExpressionInner::IfElse(
|
|
||||||
box c,
|
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
UExpressionInner::Not(box e) => {
|
UExpressionInner::Not(box e) => {
|
||||||
let e = self.fold_uint_expression(e)?.into_inner();
|
let e = self.fold_uint_expression(e)?.into_inner();
|
||||||
match e {
|
match e {
|
||||||
|
@ -1035,19 +1051,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
FieldElementExpression::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
let consequence = self.fold_field_expression(consequence)?;
|
|
||||||
let alternative = self.fold_field_expression(alternative)?;
|
|
||||||
match self.fold_boolean_expression(condition)? {
|
|
||||||
BooleanExpression::Value(true) => Ok(consequence),
|
|
||||||
BooleanExpression::Value(false) => Ok(alternative),
|
|
||||||
c => Ok(FieldElementExpression::IfElse(
|
|
||||||
box c,
|
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e => fold_field_expression(self, e),
|
e => fold_field_expression(self, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1167,19 +1170,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
},
|
},
|
||||||
None => Ok(ArrayExpressionInner::Identifier(id)),
|
None => Ok(ArrayExpressionInner::Identifier(id)),
|
||||||
},
|
},
|
||||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
let consequence = self.fold_array_expression(consequence)?;
|
|
||||||
let alternative = self.fold_array_expression(alternative)?;
|
|
||||||
match self.fold_boolean_expression(condition)? {
|
|
||||||
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
|
|
||||||
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
|
|
||||||
c => Ok(ArrayExpressionInner::IfElse(
|
|
||||||
box c,
|
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e => fold_array_expression_inner(self, ty, e),
|
e => fold_array_expression_inner(self, ty, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1197,19 +1187,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
},
|
},
|
||||||
None => Ok(StructExpressionInner::Identifier(id)),
|
None => Ok(StructExpressionInner::Identifier(id)),
|
||||||
},
|
},
|
||||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
let consequence = self.fold_struct_expression(consequence)?;
|
|
||||||
let alternative = self.fold_struct_expression(alternative)?;
|
|
||||||
match self.fold_boolean_expression(condition)? {
|
|
||||||
BooleanExpression::Value(true) => Ok(consequence.into_inner()),
|
|
||||||
BooleanExpression::Value(false) => Ok(alternative.into_inner()),
|
|
||||||
c => Ok(StructExpressionInner::IfElse(
|
|
||||||
box c,
|
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StructExpressionInner::Value(v) => {
|
StructExpressionInner::Value(v) => {
|
||||||
let v = v.into_iter().zip(ty.iter()).map(|(v, member)|
|
let v = v.into_iter().zip(ty.iter()).map(|(v, member)|
|
||||||
match self.fold_expression(v) {
|
match self.fold_expression(v) {
|
||||||
|
@ -1433,19 +1410,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
e => Ok(BooleanExpression::Not(box e)),
|
e => Ok(BooleanExpression::Not(box e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BooleanExpression::IfElse(box condition, box consequence, box alternative) => {
|
|
||||||
let consequence = self.fold_boolean_expression(consequence)?;
|
|
||||||
let alternative = self.fold_boolean_expression(alternative)?;
|
|
||||||
match self.fold_boolean_expression(condition)? {
|
|
||||||
BooleanExpression::Value(true) => Ok(consequence),
|
|
||||||
BooleanExpression::Value(false) => Ok(alternative),
|
|
||||||
c => Ok(BooleanExpression::IfElse(
|
|
||||||
box c,
|
|
||||||
box consequence,
|
|
||||||
box alternative,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
e => fold_boolean_expression(self, e),
|
e => fold_boolean_expression(self, e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1531,10 +1495,10 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn if_else_true() {
|
fn if_else_true() {
|
||||||
let e = FieldElementExpression::IfElse(
|
let e = FieldElementExpression::if_else(
|
||||||
box BooleanExpression::Value(true),
|
BooleanExpression::Value(true),
|
||||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
FieldElementExpression::Number(Bn128Field::from(3)),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -1545,10 +1509,10 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn if_else_false() {
|
fn if_else_false() {
|
||||||
let e = FieldElementExpression::IfElse(
|
let e = FieldElementExpression::if_else(
|
||||||
box BooleanExpression::Value(false),
|
BooleanExpression::Value(false),
|
||||||
box FieldElementExpression::Number(Bn128Field::from(2)),
|
FieldElementExpression::Number(Bn128Field::from(2)),
|
||||||
box FieldElementExpression::Number(Bn128Field::from(3)),
|
FieldElementExpression::Number(Bn128Field::from(3)),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
|
|
@ -193,6 +193,20 @@ pub trait Folder<'ast, T: Field>: Sized {
|
||||||
fold_block_expression(self, block)
|
fold_block_expression(self, block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_if_else_expression<
|
||||||
|
E: Expr<'ast, T>
|
||||||
|
+ Fold<'ast, T>
|
||||||
|
+ Block<'ast, T>
|
||||||
|
+ IfElse<'ast, T>
|
||||||
|
+ From<TypedExpression<'ast, T>>,
|
||||||
|
>(
|
||||||
|
&mut self,
|
||||||
|
ty: &E::Ty,
|
||||||
|
e: IfElseExpression<'ast, T, E>,
|
||||||
|
) -> IfElseOrExpression<'ast, T, E> {
|
||||||
|
fold_if_else_expression(self, ty, e)
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_member_expression<
|
fn fold_member_expression<
|
||||||
E: Expr<'ast, T> + Member<'ast, T> + From<TypedExpression<'ast, T>>,
|
E: Expr<'ast, T> + Member<'ast, T> + From<TypedExpression<'ast, T>>,
|
||||||
>(
|
>(
|
||||||
|
@ -375,13 +389,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) {
|
||||||
ArrayExpressionInner::IfElse(
|
IfElseOrExpression::IfElse(s) => ArrayExpressionInner::IfElse(s),
|
||||||
box f.fold_boolean_expression(condition),
|
IfElseOrExpression::Expression(u) => u,
|
||||||
box f.fold_array_expression(consequence),
|
},
|
||||||
box f.fold_array_expression(alternative),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
ArrayExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
|
ArrayExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
|
||||||
SelectOrExpression::Select(s) => ArrayExpressionInner::Select(s),
|
SelectOrExpression::Select(s) => ArrayExpressionInner::Select(s),
|
||||||
SelectOrExpression::Expression(u) => u,
|
SelectOrExpression::Expression(u) => u,
|
||||||
|
@ -425,13 +436,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c) {
|
||||||
StructExpressionInner::IfElse(
|
IfElseOrExpression::IfElse(s) => StructExpressionInner::IfElse(s),
|
||||||
box f.fold_boolean_expression(condition),
|
IfElseOrExpression::Expression(u) => u,
|
||||||
box f.fold_struct_expression(consequence),
|
},
|
||||||
box f.fold_struct_expression(alternative),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
StructExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
|
StructExpressionInner::Select(select) => match f.fold_select_expression(ty, select) {
|
||||||
SelectOrExpression::Select(s) => StructExpressionInner::Select(s),
|
SelectOrExpression::Select(s) => StructExpressionInner::Select(s),
|
||||||
SelectOrExpression::Expression(u) => u,
|
SelectOrExpression::Expression(u) => u,
|
||||||
|
@ -490,11 +498,11 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
|
|
||||||
FieldElementExpression::Pos(box e)
|
FieldElementExpression::Pos(box e)
|
||||||
}
|
}
|
||||||
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
|
FieldElementExpression::IfElse(c) => {
|
||||||
let cond = f.fold_boolean_expression(cond);
|
match f.fold_if_else_expression(&Type::FieldElement, c) {
|
||||||
let cons = f.fold_field_expression(cons);
|
IfElseOrExpression::IfElse(s) => FieldElementExpression::IfElse(s),
|
||||||
let alt = f.fold_field_expression(alt);
|
IfElseOrExpression::Expression(u) => u,
|
||||||
FieldElementExpression::IfElse(box cond, box cons, box alt)
|
}
|
||||||
}
|
}
|
||||||
FieldElementExpression::FunctionCall(function_call) => {
|
FieldElementExpression::FunctionCall(function_call) => {
|
||||||
match f.fold_function_call_expression(&Type::FieldElement, function_call) {
|
match f.fold_function_call_expression(&Type::FieldElement, function_call) {
|
||||||
|
@ -518,6 +526,23 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn fold_if_else_expression<
|
||||||
|
'ast,
|
||||||
|
T: Field,
|
||||||
|
E: Expr<'ast, T> + Fold<'ast, T> + IfElse<'ast, T> + From<TypedExpression<'ast, T>>,
|
||||||
|
F: Folder<'ast, T>,
|
||||||
|
>(
|
||||||
|
f: &mut F,
|
||||||
|
_: &E::Ty,
|
||||||
|
e: IfElseExpression<'ast, T, E>,
|
||||||
|
) -> IfElseOrExpression<'ast, T, E> {
|
||||||
|
IfElseOrExpression::IfElse(IfElseExpression::new(
|
||||||
|
f.fold_boolean_expression(*e.condition),
|
||||||
|
e.consequence.fold(f),
|
||||||
|
e.alternative.fold(f),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fold_member_expression<
|
pub fn fold_member_expression<
|
||||||
'ast,
|
'ast,
|
||||||
T: Field,
|
T: Field,
|
||||||
|
@ -652,12 +677,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BooleanExpression::IfElse(box cond, box cons, box alt) => {
|
BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c) {
|
||||||
let cond = f.fold_boolean_expression(cond);
|
IfElseOrExpression::IfElse(s) => BooleanExpression::IfElse(s),
|
||||||
let cons = f.fold_boolean_expression(cons);
|
IfElseOrExpression::Expression(u) => u,
|
||||||
let alt = f.fold_boolean_expression(alt);
|
},
|
||||||
BooleanExpression::IfElse(box cond, box cons, box alt)
|
|
||||||
}
|
|
||||||
BooleanExpression::Select(select) => match f.fold_select_expression(&Type::Boolean, select)
|
BooleanExpression::Select(select) => match f.fold_select_expression(&Type::Boolean, select)
|
||||||
{
|
{
|
||||||
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
||||||
|
@ -782,12 +805,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: Folder<'ast, T>>(
|
||||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||||
SelectOrExpression::Expression(u) => u,
|
SelectOrExpression::Expression(u) => u,
|
||||||
},
|
},
|
||||||
UExpressionInner::IfElse(box cond, box cons, box alt) => {
|
UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c) {
|
||||||
let cond = f.fold_boolean_expression(cond);
|
IfElseOrExpression::IfElse(s) => UExpressionInner::IfElse(s),
|
||||||
let cons = f.fold_uint_expression(cons);
|
IfElseOrExpression::Expression(u) => u,
|
||||||
let alt = f.fold_uint_expression(alt);
|
},
|
||||||
UExpressionInner::IfElse(box cond, box cons, box alt)
|
|
||||||
}
|
|
||||||
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m) {
|
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m) {
|
||||||
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
|
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
|
||||||
MemberOrExpression::Expression(u) => u,
|
MemberOrExpression::Expression(u) => u,
|
||||||
|
|
|
@ -2,8 +2,8 @@ use crate::typed_absy::types::{ArrayType, Type};
|
||||||
use crate::typed_absy::UBitwidth;
|
use crate::typed_absy::UBitwidth;
|
||||||
use crate::typed_absy::{
|
use crate::typed_absy::{
|
||||||
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
|
ArrayExpression, ArrayExpressionInner, BooleanExpression, FieldElementExpression, IfElse,
|
||||||
Select, SelectExpression, StructExpression, Typed, TypedExpression, TypedExpressionOrSpread,
|
IfElseExpression, Select, SelectExpression, StructExpression, Typed, TypedExpression,
|
||||||
TypedSpread, UExpression, UExpressionInner,
|
TypedExpressionOrSpread, TypedSpread, UExpression, UExpressionInner,
|
||||||
};
|
};
|
||||||
use num_bigint::BigUint;
|
use num_bigint::BigUint;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
|
@ -142,11 +142,7 @@ pub enum IntExpression<'ast, T> {
|
||||||
Div(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
Div(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||||
Rem(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
Rem(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||||
Pow(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
Pow(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, IntExpression<'ast, T>>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<IntExpression<'ast, T>>,
|
|
||||||
Box<IntExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Select(SelectExpression<'ast, T, IntExpression<'ast, T>>),
|
Select(SelectExpression<'ast, T, IntExpression<'ast, T>>),
|
||||||
Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
Xor(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||||
And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
And(Box<IntExpression<'ast, T>>, Box<IntExpression<'ast, T>>),
|
||||||
|
@ -261,11 +257,7 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> {
|
||||||
IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
IntExpression::RightShift(ref e, ref by) => write!(f, "({} >> {})", e, by),
|
||||||
IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
IntExpression::LeftShift(ref e, ref by) => write!(f, "({} << {})", e, by),
|
||||||
IntExpression::Not(ref e) => write!(f, "!{}", e),
|
IntExpression::Not(ref e) => write!(f, "!{}", e),
|
||||||
IntExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
IntExpression::IfElse(ref c) => write!(f, "{}", c),
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -315,13 +307,11 @@ impl<'ast, T: Field> FieldElementExpression<'ast, T> {
|
||||||
)),
|
)),
|
||||||
IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)),
|
IntExpression::Pos(box e) => Ok(Self::Pos(box Self::try_from_int(e)?)),
|
||||||
IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)),
|
IntExpression::Neg(box e) => Ok(Self::Neg(box Self::try_from_int(e)?)),
|
||||||
IntExpression::IfElse(box condition, box consequence, box alternative) => {
|
IntExpression::IfElse(c) => Ok(Self::IfElse(IfElseExpression::new(
|
||||||
Ok(Self::IfElse(
|
*c.condition,
|
||||||
box condition,
|
Self::try_from_int(*c.consequence)?,
|
||||||
box Self::try_from_int(consequence)?,
|
Self::try_from_int(*c.alternative)?,
|
||||||
box Self::try_from_int(alternative)?,
|
))),
|
||||||
))
|
|
||||||
}
|
|
||||||
IntExpression::Select(select) => {
|
IntExpression::Select(select) => {
|
||||||
let array = *select.array;
|
let array = *select.array;
|
||||||
let index = *select.index;
|
let index = *select.index;
|
||||||
|
@ -430,10 +420,10 @@ impl<'ast, T: Field> UExpression<'ast, T> {
|
||||||
Self::try_from_int(e1, bitwidth)?,
|
Self::try_from_int(e1, bitwidth)?,
|
||||||
e2,
|
e2,
|
||||||
)),
|
)),
|
||||||
IfElse(box condition, box consequence, box alternative) => Ok(UExpression::if_else(
|
IfElse(c) => Ok(UExpression::if_else(
|
||||||
condition,
|
*c.condition,
|
||||||
Self::try_from_int(consequence, bitwidth)?,
|
Self::try_from_int(*c.consequence, bitwidth)?,
|
||||||
Self::try_from_int(alternative, bitwidth)?,
|
Self::try_from_int(*c.alternative, bitwidth)?,
|
||||||
)),
|
)),
|
||||||
Select(select) => {
|
Select(select) => {
|
||||||
let array = *select.array;
|
let array = *select.array;
|
||||||
|
|
|
@ -641,13 +641,7 @@ impl<'ast, T: fmt::Display> fmt::Display for StructExpression<'ast, T> {
|
||||||
StructExpressionInner::FunctionCall(ref function_call) => {
|
StructExpressionInner::FunctionCall(ref function_call) => {
|
||||||
write!(f, "{}", function_call)
|
write!(f, "{}", function_call)
|
||||||
}
|
}
|
||||||
StructExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => {
|
StructExpressionInner::IfElse(ref c) => write!(f, "{}", c),
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
)
|
|
||||||
}
|
|
||||||
StructExpressionInner::Member(ref m) => write!(f, "{}", m),
|
StructExpressionInner::Member(ref m) => write!(f, "{}", m),
|
||||||
StructExpressionInner::Select(ref select) => write!(f, "{}", select),
|
StructExpressionInner::Select(ref select) => write!(f, "{}", select),
|
||||||
}
|
}
|
||||||
|
@ -792,6 +786,33 @@ impl<'ast, T: fmt::Display, E> fmt::Display for SelectExpression<'ast, T, E> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||||
|
pub struct IfElseExpression<'ast, T, E> {
|
||||||
|
pub condition: Box<BooleanExpression<'ast, T>>,
|
||||||
|
pub consequence: Box<E>,
|
||||||
|
pub alternative: Box<E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T, E> IfElseExpression<'ast, T, E> {
|
||||||
|
pub fn new(condition: BooleanExpression<'ast, T>, consequence: E, alternative: E) -> Self {
|
||||||
|
IfElseExpression {
|
||||||
|
condition: box condition,
|
||||||
|
consequence: box consequence,
|
||||||
|
alternative: box alternative,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: fmt::Display, E: fmt::Display> fmt::Display for IfElseExpression<'ast, T, E> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"if {} then {} else {} fi",
|
||||||
|
self.condition, self.consequence, self.alternative
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
#[derive(Clone, PartialEq, Debug, Hash, Eq)]
|
||||||
pub struct FunctionCallExpression<'ast, T, E> {
|
pub struct FunctionCallExpression<'ast, T, E> {
|
||||||
pub function_key: DeclarationFunctionKey<'ast>,
|
pub function_key: DeclarationFunctionKey<'ast>,
|
||||||
|
@ -870,11 +891,7 @@ pub enum FieldElementExpression<'ast, T> {
|
||||||
Box<FieldElementExpression<'ast, T>>,
|
Box<FieldElementExpression<'ast, T>>,
|
||||||
Box<UExpression<'ast, T>>,
|
Box<UExpression<'ast, T>>,
|
||||||
),
|
),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, Self>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<FieldElementExpression<'ast, T>>,
|
|
||||||
Box<FieldElementExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Neg(Box<FieldElementExpression<'ast, T>>),
|
Neg(Box<FieldElementExpression<'ast, T>>),
|
||||||
Pos(Box<FieldElementExpression<'ast, T>>),
|
Pos(Box<FieldElementExpression<'ast, T>>),
|
||||||
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
||||||
|
@ -974,11 +991,7 @@ pub enum BooleanExpression<'ast, T> {
|
||||||
Box<BooleanExpression<'ast, T>>,
|
Box<BooleanExpression<'ast, T>>,
|
||||||
),
|
),
|
||||||
Not(Box<BooleanExpression<'ast, T>>),
|
Not(Box<BooleanExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, Self>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Member(MemberExpression<'ast, T, Self>),
|
Member(MemberExpression<'ast, T, Self>),
|
||||||
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
FunctionCall(FunctionCallExpression<'ast, T, Self>),
|
||||||
Select(SelectExpression<'ast, T, Self>),
|
Select(SelectExpression<'ast, T, Self>),
|
||||||
|
@ -1085,11 +1098,7 @@ pub enum ArrayExpressionInner<'ast, T> {
|
||||||
Identifier(Identifier<'ast>),
|
Identifier(Identifier<'ast>),
|
||||||
Value(ArrayValue<'ast, T>),
|
Value(ArrayValue<'ast, T>),
|
||||||
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
|
FunctionCall(FunctionCallExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<ArrayExpression<'ast, T>>,
|
|
||||||
Box<ArrayExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>),
|
Member(MemberExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||||
Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>),
|
Select(SelectExpression<'ast, T, ArrayExpression<'ast, T>>),
|
||||||
Slice(
|
Slice(
|
||||||
|
@ -1190,11 +1199,7 @@ pub enum StructExpressionInner<'ast, T> {
|
||||||
Identifier(Identifier<'ast>),
|
Identifier(Identifier<'ast>),
|
||||||
Value(Vec<TypedExpression<'ast, T>>),
|
Value(Vec<TypedExpression<'ast, T>>),
|
||||||
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
|
FunctionCall(FunctionCallExpression<'ast, T, StructExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, StructExpression<'ast, T>>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<StructExpression<'ast, T>>,
|
|
||||||
Box<StructExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Member(MemberExpression<'ast, T, StructExpression<'ast, T>>),
|
Member(MemberExpression<'ast, T, StructExpression<'ast, T>>),
|
||||||
Select(SelectExpression<'ast, T, StructExpression<'ast, T>>),
|
Select(SelectExpression<'ast, T, StructExpression<'ast, T>>),
|
||||||
}
|
}
|
||||||
|
@ -1336,13 +1341,7 @@ impl<'ast, T: fmt::Display> fmt::Display for FieldElementExpression<'ast, T> {
|
||||||
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
FieldElementExpression::Pow(ref lhs, ref rhs) => write!(f, "{}**{}", lhs, rhs),
|
||||||
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
|
FieldElementExpression::Neg(ref e) => write!(f, "(-{})", e),
|
||||||
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
|
FieldElementExpression::Pos(ref e) => write!(f, "(+{})", e),
|
||||||
FieldElementExpression::IfElse(ref condition, ref consequent, ref alternative) => {
|
FieldElementExpression::IfElse(ref c) => write!(f, "{}", c),
|
||||||
write!(
|
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
)
|
|
||||||
}
|
|
||||||
FieldElementExpression::FunctionCall(ref function_call) => {
|
FieldElementExpression::FunctionCall(ref function_call) => {
|
||||||
write!(f, "{}", function_call)
|
write!(f, "{}", function_call)
|
||||||
}
|
}
|
||||||
|
@ -1376,11 +1375,7 @@ impl<'ast, T: fmt::Display> fmt::Display for UExpression<'ast, T> {
|
||||||
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
|
UExpressionInner::Pos(ref e) => write!(f, "(+{})", e),
|
||||||
UExpressionInner::Select(ref select) => write!(f, "{}", select),
|
UExpressionInner::Select(ref select) => write!(f, "{}", select),
|
||||||
UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
UExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||||
UExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
UExpressionInner::IfElse(ref c) => write!(f, "{}", c),
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
),
|
|
||||||
UExpressionInner::Member(ref m) => write!(f, "{}", m),
|
UExpressionInner::Member(ref m) => write!(f, "{}", m),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1409,11 +1404,7 @@ impl<'ast, T: fmt::Display> fmt::Display for BooleanExpression<'ast, T> {
|
||||||
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
BooleanExpression::Not(ref exp) => write!(f, "!{}", exp),
|
||||||
BooleanExpression::Value(b) => write!(f, "{}", b),
|
BooleanExpression::Value(b) => write!(f, "{}", b),
|
||||||
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
BooleanExpression::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||||
BooleanExpression::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
BooleanExpression::IfElse(ref c) => write!(f, "{}", c),
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
),
|
|
||||||
BooleanExpression::Member(ref m) => write!(f, "{}", m),
|
BooleanExpression::Member(ref m) => write!(f, "{}", m),
|
||||||
BooleanExpression::Select(ref select) => write!(f, "{}", select),
|
BooleanExpression::Select(ref select) => write!(f, "{}", select),
|
||||||
}
|
}
|
||||||
|
@ -1435,11 +1426,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ArrayExpressionInner<'ast, T> {
|
||||||
.join(", ")
|
.join(", ")
|
||||||
),
|
),
|
||||||
ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
ArrayExpressionInner::FunctionCall(ref function_call) => write!(f, "{}", function_call),
|
||||||
ArrayExpressionInner::IfElse(ref condition, ref consequent, ref alternative) => write!(
|
ArrayExpressionInner::IfElse(ref c) => write!(f, "{}", c),
|
||||||
f,
|
|
||||||
"if {} then {} else {} fi",
|
|
||||||
condition, consequent, alternative
|
|
||||||
),
|
|
||||||
ArrayExpressionInner::Member(ref m) => write!(f, "{}", m),
|
ArrayExpressionInner::Member(ref m) => write!(f, "{}", m),
|
||||||
ArrayExpressionInner::Select(ref select) => write!(f, "{}", select),
|
ArrayExpressionInner::Select(ref select) => write!(f, "{}", select),
|
||||||
ArrayExpressionInner::Slice(ref a, ref from, ref to) => {
|
ArrayExpressionInner::Slice(ref a, ref from, ref to) => {
|
||||||
|
@ -1619,6 +1606,11 @@ pub enum MemberOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||||
Expression(E::Inner),
|
Expression(E::Inner),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum IfElseOrExpression<'ast, T, E: Expr<'ast, T>> {
|
||||||
|
IfElse(IfElseExpression<'ast, T, E>),
|
||||||
|
Expression(E::Inner),
|
||||||
|
}
|
||||||
|
|
||||||
pub trait IfElse<'ast, T> {
|
pub trait IfElse<'ast, T> {
|
||||||
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
|
fn if_else(condition: BooleanExpression<'ast, T>, consequence: Self, alternative: Self)
|
||||||
-> Self;
|
-> Self;
|
||||||
|
@ -1630,7 +1622,7 @@ impl<'ast, T> IfElse<'ast, T> for FieldElementExpression<'ast, T> {
|
||||||
consequence: Self,
|
consequence: Self,
|
||||||
alternative: Self,
|
alternative: Self,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
FieldElementExpression::IfElse(box condition, box consequence, box alternative)
|
FieldElementExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1640,7 +1632,7 @@ impl<'ast, T> IfElse<'ast, T> for IntExpression<'ast, T> {
|
||||||
consequence: Self,
|
consequence: Self,
|
||||||
alternative: Self,
|
alternative: Self,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
IntExpression::IfElse(box condition, box consequence, box alternative)
|
IntExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1650,7 +1642,7 @@ impl<'ast, T> IfElse<'ast, T> for BooleanExpression<'ast, T> {
|
||||||
consequence: Self,
|
consequence: Self,
|
||||||
alternative: Self,
|
alternative: Self,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
BooleanExpression::IfElse(box condition, box consequence, box alternative)
|
BooleanExpression::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1662,7 +1654,8 @@ impl<'ast, T> IfElse<'ast, T> for UExpression<'ast, T> {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let bitwidth = consequence.bitwidth;
|
let bitwidth = consequence.bitwidth;
|
||||||
|
|
||||||
UExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(bitwidth)
|
UExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
|
.annotate(bitwidth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1674,7 +1667,7 @@ impl<'ast, T: Clone> IfElse<'ast, T> for ArrayExpression<'ast, T> {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let ty = consequence.inner_type().clone();
|
let ty = consequence.inner_type().clone();
|
||||||
let size = consequence.size();
|
let size = consequence.size();
|
||||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative)
|
ArrayExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
.annotate(ty, size)
|
.annotate(ty, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1686,7 +1679,8 @@ impl<'ast, T: Clone> IfElse<'ast, T> for StructExpression<'ast, T> {
|
||||||
alternative: Self,
|
alternative: Self,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let ty = consequence.ty().clone();
|
let ty = consequence.ty().clone();
|
||||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative).annotate(ty)
|
StructExpressionInner::IfElse(IfElseExpression::new(condition, consequence, alternative))
|
||||||
|
.annotate(ty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -142,6 +142,16 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
|
||||||
fold_types(self, tys)
|
fold_types(self, tys)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_if_else_expression<
|
||||||
|
E: Expr<'ast, T> + PartialEq + IfElse<'ast, T> + ResultFold<'ast, T>,
|
||||||
|
>(
|
||||||
|
&mut self,
|
||||||
|
ty: &E::Ty,
|
||||||
|
e: IfElseExpression<'ast, T, E>,
|
||||||
|
) -> Result<IfElseOrExpression<'ast, T, E>, Self::Error> {
|
||||||
|
fold_if_else_expression(self, ty, e)
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_block_expression<E: ResultFold<'ast, T>>(
|
fn fold_block_expression<E: ResultFold<'ast, T>>(
|
||||||
&mut self,
|
&mut self,
|
||||||
block: BlockExpression<'ast, T, E>,
|
block: BlockExpression<'ast, T, E>,
|
||||||
|
@ -415,13 +425,10 @@ pub fn fold_array_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
ArrayExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? {
|
||||||
ArrayExpressionInner::IfElse(
|
IfElseOrExpression::IfElse(c) => ArrayExpressionInner::IfElse(c),
|
||||||
box f.fold_boolean_expression(condition)?,
|
IfElseOrExpression::Expression(u) => u,
|
||||||
box f.fold_array_expression(consequence)?,
|
},
|
||||||
box f.fold_array_expression(alternative)?,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
ArrayExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
|
ArrayExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
|
||||||
MemberOrExpression::Member(m) => ArrayExpressionInner::Member(m),
|
MemberOrExpression::Member(m) => ArrayExpressionInner::Member(m),
|
||||||
MemberOrExpression::Expression(u) => u,
|
MemberOrExpression::Expression(u) => u,
|
||||||
|
@ -469,13 +476,10 @@ pub fn fold_struct_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
StructExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
StructExpressionInner::IfElse(c) => match f.fold_if_else_expression(ty, c)? {
|
||||||
StructExpressionInner::IfElse(
|
IfElseOrExpression::IfElse(c) => StructExpressionInner::IfElse(c),
|
||||||
box f.fold_boolean_expression(condition)?,
|
IfElseOrExpression::Expression(u) => u,
|
||||||
box f.fold_struct_expression(consequence)?,
|
},
|
||||||
box f.fold_struct_expression(alternative)?,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
StructExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
|
StructExpressionInner::Member(m) => match f.fold_member_expression(ty, m)? {
|
||||||
MemberOrExpression::Member(m) => StructExpressionInner::Member(m),
|
MemberOrExpression::Member(m) => StructExpressionInner::Member(m),
|
||||||
MemberOrExpression::Expression(u) => u,
|
MemberOrExpression::Expression(u) => u,
|
||||||
|
@ -535,11 +539,11 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
|
|
||||||
FieldElementExpression::Pos(box e)
|
FieldElementExpression::Pos(box e)
|
||||||
}
|
}
|
||||||
FieldElementExpression::IfElse(box cond, box cons, box alt) => {
|
FieldElementExpression::IfElse(c) => {
|
||||||
let cond = f.fold_boolean_expression(cond)?;
|
match f.fold_if_else_expression(&Type::FieldElement, c)? {
|
||||||
let cons = f.fold_field_expression(cons)?;
|
IfElseOrExpression::IfElse(c) => FieldElementExpression::IfElse(c),
|
||||||
let alt = f.fold_field_expression(alt)?;
|
IfElseOrExpression::Expression(u) => u,
|
||||||
FieldElementExpression::IfElse(box cond, box cons, box alt)
|
}
|
||||||
}
|
}
|
||||||
FieldElementExpression::FunctionCall(function_call) => {
|
FieldElementExpression::FunctionCall(function_call) => {
|
||||||
match f.fold_function_call_expression(&Type::FieldElement, function_call)? {
|
match f.fold_function_call_expression(&Type::FieldElement, function_call)? {
|
||||||
|
@ -589,6 +593,27 @@ pub fn fold_block_expression<'ast, T: Field, E: ResultFold<'ast, T>, F: ResultFo
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn fold_if_else_expression<
|
||||||
|
'ast,
|
||||||
|
T: Field,
|
||||||
|
E: Expr<'ast, T>
|
||||||
|
+ IfElse<'ast, T>
|
||||||
|
+ PartialEq
|
||||||
|
+ ResultFold<'ast, T>
|
||||||
|
+ From<TypedExpression<'ast, T>>,
|
||||||
|
F: ResultFolder<'ast, T>,
|
||||||
|
>(
|
||||||
|
f: &mut F,
|
||||||
|
_: &E::Ty,
|
||||||
|
e: IfElseExpression<'ast, T, E>,
|
||||||
|
) -> Result<IfElseOrExpression<'ast, T, E>, F::Error> {
|
||||||
|
Ok(IfElseOrExpression::IfElse(IfElseExpression::new(
|
||||||
|
f.fold_boolean_expression(*e.condition)?,
|
||||||
|
e.consequence.fold(f)?,
|
||||||
|
e.alternative.fold(f)?,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn fold_member_expression<
|
pub fn fold_member_expression<
|
||||||
'ast,
|
'ast,
|
||||||
T: Field,
|
T: Field,
|
||||||
|
@ -739,12 +764,10 @@ pub fn fold_boolean_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
FunctionCallOrExpression::Expression(u) => u,
|
FunctionCallOrExpression::Expression(u) => u,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BooleanExpression::IfElse(box cond, box cons, box alt) => {
|
BooleanExpression::IfElse(c) => match f.fold_if_else_expression(&Type::Boolean, c)? {
|
||||||
let cond = f.fold_boolean_expression(cond)?;
|
IfElseOrExpression::IfElse(c) => BooleanExpression::IfElse(c),
|
||||||
let cons = f.fold_boolean_expression(cons)?;
|
IfElseOrExpression::Expression(u) => u,
|
||||||
let alt = f.fold_boolean_expression(alt)?;
|
},
|
||||||
BooleanExpression::IfElse(box cond, box cons, box alt)
|
|
||||||
}
|
|
||||||
BooleanExpression::Select(select) => {
|
BooleanExpression::Select(select) => {
|
||||||
match f.fold_select_expression(&Type::Boolean, select)? {
|
match f.fold_select_expression(&Type::Boolean, select)? {
|
||||||
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
SelectOrExpression::Select(s) => BooleanExpression::Select(s),
|
||||||
|
@ -869,12 +892,10 @@ pub fn fold_uint_expression_inner<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
||||||
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
SelectOrExpression::Select(s) => UExpressionInner::Select(s),
|
||||||
SelectOrExpression::Expression(u) => u,
|
SelectOrExpression::Expression(u) => u,
|
||||||
},
|
},
|
||||||
UExpressionInner::IfElse(box cond, box cons, box alt) => {
|
UExpressionInner::IfElse(c) => match f.fold_if_else_expression(&ty, c)? {
|
||||||
let cond = f.fold_boolean_expression(cond)?;
|
IfElseOrExpression::IfElse(c) => UExpressionInner::IfElse(c),
|
||||||
let cons = f.fold_uint_expression(cons)?;
|
IfElseOrExpression::Expression(u) => u,
|
||||||
let alt = f.fold_uint_expression(alt)?;
|
},
|
||||||
UExpressionInner::IfElse(box cond, box cons, box alt)
|
|
||||||
}
|
|
||||||
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m)? {
|
UExpressionInner::Member(m) => match f.fold_member_expression(&ty, m)? {
|
||||||
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
|
MemberOrExpression::Member(m) => UExpressionInner::Member(m),
|
||||||
MemberOrExpression::Expression(u) => u,
|
MemberOrExpression::Expression(u) => u,
|
||||||
|
|
|
@ -193,11 +193,7 @@ pub enum UExpressionInner<'ast, T> {
|
||||||
FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>),
|
FunctionCall(FunctionCallExpression<'ast, T, UExpression<'ast, T>>),
|
||||||
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
LeftShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||||
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
RightShift(Box<UExpression<'ast, T>>, Box<UExpression<'ast, T>>),
|
||||||
IfElse(
|
IfElse(IfElseExpression<'ast, T, UExpression<'ast, T>>),
|
||||||
Box<BooleanExpression<'ast, T>>,
|
|
||||||
Box<UExpression<'ast, T>>,
|
|
||||||
Box<UExpression<'ast, T>>,
|
|
||||||
),
|
|
||||||
Member(MemberExpression<'ast, T, UExpression<'ast, T>>),
|
Member(MemberExpression<'ast, T, UExpression<'ast, T>>),
|
||||||
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
|
Select(SelectExpression<'ast, T, UExpression<'ast, T>>),
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
def throwing_bound<N>(u32 x) -> u32:
|
||||||
|
assert(x == N)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
// this compiles: the conditional, even though it can throw, has a constant compile-time value of `1`
|
||||||
|
// However, the assertions are still checked at runtime, which leads to panics without branch isolation.
|
||||||
|
def main(u32 x):
|
||||||
|
for u32 i in 0..if x == 0 then throwing_bound::<0>(x) else throwing_bound::<1>(x) fi do
|
||||||
|
endfor
|
||||||
|
return
|
|
@ -0,0 +1,51 @@
|
||||||
|
{
|
||||||
|
"entry_point": "./tests/tests/panics/conditional_bound_throw.zok",
|
||||||
|
"curves": ["Bn128"],
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"values": [
|
||||||
|
"0"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"Err": {
|
||||||
|
"UnsatisfiedConstraint": {
|
||||||
|
"left": "0",
|
||||||
|
"right": "1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"values": [
|
||||||
|
"1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"Err": {
|
||||||
|
"UnsatisfiedConstraint": {
|
||||||
|
"left": "1",
|
||||||
|
"right": "0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"values": [
|
||||||
|
"2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"Err": {
|
||||||
|
"UnsatisfiedConstraint": {
|
||||||
|
"left": "2",
|
||||||
|
"right": "0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
Loading…
Reference in a new issue